feat: add useful callbacks for GMLVQ

omega trace normalization and matrix profile visualization
This commit is contained in:
Alexander Engelsberger 2022-09-21 13:23:43 +02:00
parent ba50dfba50
commit 365e0fb931
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7

View File

@ -1,12 +1,15 @@
import logging
import warnings import warnings
from enum import Enum
from typing import Optional, Type from typing import Optional, Type
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.architectures.base import BaseYArchitecture, Steps from prototorch.models.architectures.base import BaseYArchitecture, Steps
from prototorch.models.architectures.comparison import OmegaComparisonMixin
from prototorch.models.library.gmlvq import GMLVQ from prototorch.models.library.gmlvq import GMLVQ
from prototorch.models.vis import Vis2DAbstract from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d from prototorch.utils.utils import mesh2d
@ -213,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
# add to tensorboard # add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger): if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure( trainer.logger.experiment.add_figure(
f"lambda_matrix", "lambda_matrix",
self.fig, self.fig,
trainer.global_step, trainer.global_step,
) )
@ -221,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
warnings.warn( warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead." f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
) )
class Profiles(Enum):
'''
Available Profiles
'''
RELEVANCE = 'relevance'
INFLUENCE = 'influence'
def __str__(self):
return str(self.value)
class PlotMatrixProfiles(pl.Callback):
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
self.profile = profile
def on_train_start(self, trainer, pl_module: GMLVQ):
'''
Plot initial profile.
'''
self._plot_profile(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
'''
Plot after every epoch.
'''
self._plot_profile(trainer, pl_module)
def _plot_profile(self, trainer, pl_module: GMLVQ):
fig, ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = torch.abs(pl_module.lambda_matrix)
if self.profile == Profiles.RELEVANCE:
profile_value = l_matrix.diag()
elif self.profile == Profiles.INFLUENCE:
profile_value = l_matrix.sum(0)
# plot lambda matrix
ax.plot(profile_value.detach().numpy())
# add title
ax.set_title(f'{self.profile} profile')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"{self.profile}_matrix",
fig,
trainer.global_step,
)
else:
class_name = self.__class__.__name__
logger_name = trainer.logger.__class__.__name__
warnings.warn(
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
)
class OmegaTraceNormalization(pl.Callback):
'''
Trace normalization of the Omega Matrix.
'''
__epsilon = torch.finfo(torch.float32).eps
def on_train_epoch_end(self, trainer: "pl.Trainer",
pl_module: OmegaComparisonMixin) -> None:
omega = pl_module.parameter_omega
denominator = torch.sqrt(torch.trace(omega.T @ omega))
logging.debug(
"Apply Omega Trace Normalization: demoninator=%f",
denominator.item(),
)
pl_module.parameter_omega = omega / (denominator + self.__epsilon)