diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index 33b8a32..cea73e8 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -1,12 +1,15 @@ +import logging import warnings +from enum import Enum from typing import Optional, Type +import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch import torchmetrics -from matplotlib import pyplot as plt from prototorch.models.architectures.base import BaseYArchitecture, Steps +from prototorch.models.architectures.comparison import OmegaComparisonMixin from prototorch.models.library.gmlvq import GMLVQ from prototorch.models.vis import Vis2DAbstract from prototorch.utils.utils import mesh2d @@ -213,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback): # add to tensorboard if isinstance(trainer.logger, TensorBoardLogger): trainer.logger.experiment.add_figure( - f"lambda_matrix", + "lambda_matrix", self.fig, trainer.global_step, ) @@ -221,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback): warnings.warn( f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead." ) + + +class Profiles(Enum): + ''' + Available Profiles + ''' + RELEVANCE = 'relevance' + INFLUENCE = 'influence' + + def __str__(self): + return str(self.value) + + +class PlotMatrixProfiles(pl.Callback): + + def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None: + super().__init__() + self.cmap = cmap + self.profile = profile + + def on_train_start(self, trainer, pl_module: GMLVQ): + ''' + Plot initial profile. + ''' + self._plot_profile(trainer, pl_module) + + def on_train_epoch_end(self, trainer, pl_module: GMLVQ): + ''' + Plot after every epoch. + ''' + self._plot_profile(trainer, pl_module) + + def _plot_profile(self, trainer, pl_module: GMLVQ): + + fig, ax = plt.subplots(1, 1) + + # plot lambda matrix + l_matrix = torch.abs(pl_module.lambda_matrix) + + if self.profile == Profiles.RELEVANCE: + profile_value = l_matrix.diag() + elif self.profile == Profiles.INFLUENCE: + profile_value = l_matrix.sum(0) + + # plot lambda matrix + ax.plot(profile_value.detach().numpy()) + + # add title + ax.set_title(f'{self.profile} profile') + + # add to tensorboard + if isinstance(trainer.logger, TensorBoardLogger): + trainer.logger.experiment.add_figure( + f"{self.profile}_matrix", + fig, + trainer.global_step, + ) + else: + class_name = self.__class__.__name__ + logger_name = trainer.logger.__class__.__name__ + warnings.warn( + f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead." + ) + + +class OmegaTraceNormalization(pl.Callback): + ''' + Trace normalization of the Omega Matrix. + ''' + __epsilon = torch.finfo(torch.float32).eps + + def on_train_epoch_end(self, trainer: "pl.Trainer", + pl_module: OmegaComparisonMixin) -> None: + + omega = pl_module.parameter_omega + denominator = torch.sqrt(torch.trace(omega.T @ omega)) + logging.debug( + "Apply Omega Trace Normalization: demoninator=%f", + denominator.item(), + ) + pl_module.parameter_omega = omega / (denominator + self.__epsilon)