feat: add useful callbacks for GMLVQ
omega trace normalization and matrix profile visualization
This commit is contained in:
parent
ba50dfba50
commit
365e0fb931
@ -1,12 +1,15 @@
|
||||
import logging
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.models.architectures.base import BaseYArchitecture, Steps
|
||||
from prototorch.models.architectures.comparison import OmegaComparisonMixin
|
||||
from prototorch.models.library.gmlvq import GMLVQ
|
||||
from prototorch.models.vis import Vis2DAbstract
|
||||
from prototorch.utils.utils import mesh2d
|
||||
@ -213,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
f"lambda_matrix",
|
||||
"lambda_matrix",
|
||||
self.fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
@ -221,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class Profiles(Enum):
|
||||
'''
|
||||
Available Profiles
|
||||
'''
|
||||
RELEVANCE = 'relevance'
|
||||
INFLUENCE = 'influence'
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class PlotMatrixProfiles(pl.Callback):
|
||||
|
||||
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
|
||||
super().__init__()
|
||||
self.cmap = cmap
|
||||
self.profile = profile
|
||||
|
||||
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot initial profile.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot after every epoch.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def _plot_profile(self, trainer, pl_module: GMLVQ):
|
||||
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
|
||||
# plot lambda matrix
|
||||
l_matrix = torch.abs(pl_module.lambda_matrix)
|
||||
|
||||
if self.profile == Profiles.RELEVANCE:
|
||||
profile_value = l_matrix.diag()
|
||||
elif self.profile == Profiles.INFLUENCE:
|
||||
profile_value = l_matrix.sum(0)
|
||||
|
||||
# plot lambda matrix
|
||||
ax.plot(profile_value.detach().numpy())
|
||||
|
||||
# add title
|
||||
ax.set_title(f'{self.profile} profile')
|
||||
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
f"{self.profile}_matrix",
|
||||
fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
else:
|
||||
class_name = self.__class__.__name__
|
||||
logger_name = trainer.logger.__class__.__name__
|
||||
warnings.warn(
|
||||
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class OmegaTraceNormalization(pl.Callback):
|
||||
'''
|
||||
Trace normalization of the Omega Matrix.
|
||||
'''
|
||||
__epsilon = torch.finfo(torch.float32).eps
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer",
|
||||
pl_module: OmegaComparisonMixin) -> None:
|
||||
|
||||
omega = pl_module.parameter_omega
|
||||
denominator = torch.sqrt(torch.trace(omega.T @ omega))
|
||||
logging.debug(
|
||||
"Apply Omega Trace Normalization: demoninator=%f",
|
||||
denominator.item(),
|
||||
)
|
||||
pl_module.parameter_omega = omega / (denominator + self.__epsilon)
|
||||
|
Loading…
Reference in New Issue
Block a user