feat: add useful callbacks for GMLVQ
omega trace normalization and matrix profile visualization
This commit is contained in:
parent
ba50dfba50
commit
365e0fb931
@ -1,12 +1,15 @@
|
|||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
from prototorch.models.architectures.base import BaseYArchitecture, Steps
|
from prototorch.models.architectures.base import BaseYArchitecture, Steps
|
||||||
|
from prototorch.models.architectures.comparison import OmegaComparisonMixin
|
||||||
from prototorch.models.library.gmlvq import GMLVQ
|
from prototorch.models.library.gmlvq import GMLVQ
|
||||||
from prototorch.models.vis import Vis2DAbstract
|
from prototorch.models.vis import Vis2DAbstract
|
||||||
from prototorch.utils.utils import mesh2d
|
from prototorch.utils.utils import mesh2d
|
||||||
@ -213,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
|
|||||||
# add to tensorboard
|
# add to tensorboard
|
||||||
if isinstance(trainer.logger, TensorBoardLogger):
|
if isinstance(trainer.logger, TensorBoardLogger):
|
||||||
trainer.logger.experiment.add_figure(
|
trainer.logger.experiment.add_figure(
|
||||||
f"lambda_matrix",
|
"lambda_matrix",
|
||||||
self.fig,
|
self.fig,
|
||||||
trainer.global_step,
|
trainer.global_step,
|
||||||
)
|
)
|
||||||
@ -221,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
|
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Profiles(Enum):
|
||||||
|
'''
|
||||||
|
Available Profiles
|
||||||
|
'''
|
||||||
|
RELEVANCE = 'relevance'
|
||||||
|
INFLUENCE = 'influence'
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class PlotMatrixProfiles(pl.Callback):
|
||||||
|
|
||||||
|
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.cmap = cmap
|
||||||
|
self.profile = profile
|
||||||
|
|
||||||
|
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||||
|
'''
|
||||||
|
Plot initial profile.
|
||||||
|
'''
|
||||||
|
self._plot_profile(trainer, pl_module)
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||||
|
'''
|
||||||
|
Plot after every epoch.
|
||||||
|
'''
|
||||||
|
self._plot_profile(trainer, pl_module)
|
||||||
|
|
||||||
|
def _plot_profile(self, trainer, pl_module: GMLVQ):
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(1, 1)
|
||||||
|
|
||||||
|
# plot lambda matrix
|
||||||
|
l_matrix = torch.abs(pl_module.lambda_matrix)
|
||||||
|
|
||||||
|
if self.profile == Profiles.RELEVANCE:
|
||||||
|
profile_value = l_matrix.diag()
|
||||||
|
elif self.profile == Profiles.INFLUENCE:
|
||||||
|
profile_value = l_matrix.sum(0)
|
||||||
|
|
||||||
|
# plot lambda matrix
|
||||||
|
ax.plot(profile_value.detach().numpy())
|
||||||
|
|
||||||
|
# add title
|
||||||
|
ax.set_title(f'{self.profile} profile')
|
||||||
|
|
||||||
|
# add to tensorboard
|
||||||
|
if isinstance(trainer.logger, TensorBoardLogger):
|
||||||
|
trainer.logger.experiment.add_figure(
|
||||||
|
f"{self.profile}_matrix",
|
||||||
|
fig,
|
||||||
|
trainer.global_step,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
class_name = self.__class__.__name__
|
||||||
|
logger_name = trainer.logger.__class__.__name__
|
||||||
|
warnings.warn(
|
||||||
|
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OmegaTraceNormalization(pl.Callback):
|
||||||
|
'''
|
||||||
|
Trace normalization of the Omega Matrix.
|
||||||
|
'''
|
||||||
|
__epsilon = torch.finfo(torch.float32).eps
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer: "pl.Trainer",
|
||||||
|
pl_module: OmegaComparisonMixin) -> None:
|
||||||
|
|
||||||
|
omega = pl_module.parameter_omega
|
||||||
|
denominator = torch.sqrt(torch.trace(omega.T @ omega))
|
||||||
|
logging.debug(
|
||||||
|
"Apply Omega Trace Normalization: demoninator=%f",
|
||||||
|
denominator.item(),
|
||||||
|
)
|
||||||
|
pl_module.parameter_omega = omega / (denominator + self.__epsilon)
|
||||||
|
Loading…
Reference in New Issue
Block a user