feat: add useful callbacks for GMLVQ
omega trace normalization and matrix profile visualization
This commit is contained in:
		| @@ -1,12 +1,15 @@ | ||||
| import logging | ||||
| import warnings | ||||
| from enum import Enum | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
| import numpy as np | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| import torchmetrics | ||||
| from matplotlib import pyplot as plt | ||||
| from prototorch.models.architectures.base import BaseYArchitecture, Steps | ||||
| from prototorch.models.architectures.comparison import OmegaComparisonMixin | ||||
| from prototorch.models.library.gmlvq import GMLVQ | ||||
| from prototorch.models.vis import Vis2DAbstract | ||||
| from prototorch.utils.utils import mesh2d | ||||
| @@ -213,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback): | ||||
|         # add to tensorboard | ||||
|         if isinstance(trainer.logger, TensorBoardLogger): | ||||
|             trainer.logger.experiment.add_figure( | ||||
|                 f"lambda_matrix", | ||||
|                 "lambda_matrix", | ||||
|                 self.fig, | ||||
|                 trainer.global_step, | ||||
|             ) | ||||
| @@ -221,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback): | ||||
|             warnings.warn( | ||||
|                 f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead." | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class Profiles(Enum): | ||||
|     ''' | ||||
|     Available Profiles | ||||
|     ''' | ||||
|     RELEVANCE = 'relevance' | ||||
|     INFLUENCE = 'influence' | ||||
|  | ||||
|     def __str__(self): | ||||
|         return str(self.value) | ||||
|  | ||||
|  | ||||
| class PlotMatrixProfiles(pl.Callback): | ||||
|  | ||||
|     def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None: | ||||
|         super().__init__() | ||||
|         self.cmap = cmap | ||||
|         self.profile = profile | ||||
|  | ||||
|     def on_train_start(self, trainer, pl_module: GMLVQ): | ||||
|         ''' | ||||
|         Plot initial profile. | ||||
|         ''' | ||||
|         self._plot_profile(trainer, pl_module) | ||||
|  | ||||
|     def on_train_epoch_end(self, trainer, pl_module: GMLVQ): | ||||
|         ''' | ||||
|         Plot after every epoch. | ||||
|         ''' | ||||
|         self._plot_profile(trainer, pl_module) | ||||
|  | ||||
|     def _plot_profile(self, trainer, pl_module: GMLVQ): | ||||
|  | ||||
|         fig, ax = plt.subplots(1, 1) | ||||
|  | ||||
|         # plot lambda matrix | ||||
|         l_matrix = torch.abs(pl_module.lambda_matrix) | ||||
|  | ||||
|         if self.profile == Profiles.RELEVANCE: | ||||
|             profile_value = l_matrix.diag() | ||||
|         elif self.profile == Profiles.INFLUENCE: | ||||
|             profile_value = l_matrix.sum(0) | ||||
|  | ||||
|         # plot lambda matrix | ||||
|         ax.plot(profile_value.detach().numpy()) | ||||
|  | ||||
|         # add title | ||||
|         ax.set_title(f'{self.profile} profile') | ||||
|  | ||||
|         # add to tensorboard | ||||
|         if isinstance(trainer.logger, TensorBoardLogger): | ||||
|             trainer.logger.experiment.add_figure( | ||||
|                 f"{self.profile}_matrix", | ||||
|                 fig, | ||||
|                 trainer.global_step, | ||||
|             ) | ||||
|         else: | ||||
|             class_name = self.__class__.__name__ | ||||
|             logger_name = trainer.logger.__class__.__name__ | ||||
|             warnings.warn( | ||||
|                 f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead." | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class OmegaTraceNormalization(pl.Callback): | ||||
|     ''' | ||||
|     Trace normalization of the Omega Matrix. | ||||
|     ''' | ||||
|     __epsilon = torch.finfo(torch.float32).eps | ||||
|  | ||||
|     def on_train_epoch_end(self, trainer: "pl.Trainer", | ||||
|                            pl_module: OmegaComparisonMixin) -> None: | ||||
|  | ||||
|         omega = pl_module.parameter_omega | ||||
|         denominator = torch.sqrt(torch.trace(omega.T @ omega)) | ||||
|         logging.debug( | ||||
|             "Apply Omega Trace Normalization: demoninator=%f", | ||||
|             denominator.item(), | ||||
|         ) | ||||
|         pl_module.parameter_omega = omega / (denominator + self.__epsilon) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user