From c6f718a1d4c010734bef187d255df189a7616c4f Mon Sep 17 00:00:00 2001 From: julius Date: Tue, 7 Nov 2023 16:44:13 +0100 Subject: [PATCH] GMLMLVQ: allow for 2 or more omega layers --- src/prototorch/models/glvq.py | 88 ++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index 4de706b..7709e93 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -1,14 +1,17 @@ """Models based on the GLVQ framework.""" +from typing import LiteralString + import torch +from numpy.typing import NDArray from prototorch.core.competitions import wtac from prototorch.core.distances import ( + ML_omega_distance, lomega_distance, omega_distance, - ML_omega_distance, squared_euclidean_distance, ) -from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI) +from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer from prototorch.core.losses import ( GLVQLoss, lvq1_loss, @@ -16,7 +19,7 @@ from prototorch.core.losses import ( ) from prototorch.core.transforms import LinearTransform from prototorch.nn.wrappers import LambdaLayer, LossLayer -from torch.nn.parameter import Parameter +from torch.nn import Parameter, ParameterList from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel from .extras import ltangent_distance, orthogonalization @@ -46,26 +49,28 @@ class GLVQ(SupervisedPrototypeModel): def initialize_prototype_win_ratios(self): self.register_buffer( - "prototype_win_ratios", - torch.zeros(self.num_prototypes, device=self.device)) + "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device) + ) def on_train_epoch_start(self): self.initialize_prototype_win_ratios() def log_prototype_win_ratios(self, distances): batch_size = len(distances) - prototype_wc = torch.zeros(self.num_prototypes, - dtype=torch.long, - device=self.device) - wi, wc = torch.unique(distances.min(dim=-1).indices, - sorted=True, - return_counts=True) + prototype_wc = torch.zeros( + self.num_prototypes, dtype=torch.long, device=self.device + ) + wi, wc = torch.unique( + distances.min(dim=-1).indices, sorted=True, return_counts=True + ) prototype_wc[wi] = wc prototype_wr = prototype_wc / batch_size - self.prototype_win_ratios = torch.vstack([ - self.prototype_win_ratios, - prototype_wr, - ]) + self.prototype_win_ratios = torch.vstack( + [ + self.prototype_win_ratios, + prototype_wr, + ] + ) def shared_step(self, batch, batch_idx): x, y = batch @@ -110,11 +115,9 @@ class SiameseGLVQ(GLVQ): """ - def __init__(self, - hparams, - backbone=torch.nn.Identity(), - both_path_gradients=False, - **kwargs): + def __init__( + self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs + ): distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance) super().__init__(hparams, distance_fn=distance_fn, **kwargs) self.backbone = backbone @@ -176,6 +179,7 @@ class GRLVQ(SiameseGLVQ): TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. """ + _relevances: torch.Tensor def __init__(self, hparams, **kwargs): @@ -186,8 +190,7 @@ class GRLVQ(SiameseGLVQ): self.register_parameter("_relevances", Parameter(relevances)) # Override the backbone - self.backbone = LambdaLayer(self._apply_relevances, - name="relevance scaling") + self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling") def _apply_relevances(self, x): return x @ torch.diag(self._relevances) @@ -211,8 +214,9 @@ class SiameseGMLVQ(SiameseGLVQ): super().__init__(hparams, **kwargs) # Override the backbone - omega_initializer = kwargs.get("omega_initializer", - EyeLinearTransformInitializer()) + omega_initializer = kwargs.get( + "omega_initializer", EyeLinearTransformInitializer() + ) self.backbone = LinearTransform( self.hparams["input_dim"], self.hparams["latent_dim"], @@ -232,48 +236,46 @@ class SiameseGMLVQ(SiameseGLVQ): class GMLMLVQ(GLVQ): """Generalized Multi-Layer Matrix Learning Vector Quantization. + Masks are applied to the omega layers to achieve sparsity and constrain + learning to certain items of each omega. Implemented as a regular GLVQ network that simply uses a different distance function. This makes it easier to implement a localized variant. """ # Parameters - _omega_0: torch.Tensor - _omega_1: torch.Tensor + _omegas: list[torch.Tensor] + masks: list[torch.Tensor] def __init__(self, hparams, **kwargs): distance_fn = kwargs.pop("distance_fn", ML_omega_distance) super().__init__(hparams, distance_fn=distance_fn, **kwargs) # Additional parameters - omega_initializer = kwargs.get("omega_initializer") masks = kwargs.get("masks") - omega_0 = LLTI(masks[0]).generate(1, 1) - omega_1 = LLTI(masks[1]).generate(1, 1) - self.register_parameter("_omega_0", Parameter(omega_0)) - self.register_parameter("_omega_1", Parameter(omega_1)) - self.mask_0 = masks[0] - self.mask_1 = masks[1] + for i, _mask in enumerate(masks): + self.register_buffer(f"_mask_{i}", _mask) + self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)] + self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks]) @property def omega_matrices(self): - return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()] + return [_omega.detach().cpu() for _omega in self._omegas] @property def lambda_matrix(self): + # TODO update to respective lambda calculation rules. omega = self._omega.detach() # (input_dim, latent_dim) lam = omega @ omega.T return lam.detach().cpu() def compute_distances(self, x): protos, _ = self.proto_layer() - distances = self.distance_layer(x, protos, self._omega_0, - self._omega_1, self.mask_0, - self.mask_1) + distances = self.distance_layer(x, protos, self._omegas, self._masks) return distances def extra_repr(self): - return f"(omega): (shape: {tuple(self._omega.shape)})" + return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})" class GMLVQ(GLVQ): @@ -292,10 +294,12 @@ class GMLVQ(GLVQ): super().__init__(hparams, distance_fn=distance_fn, **kwargs) # Additional parameters - omega_initializer = kwargs.get("omega_initializer", - EyeLinearTransformInitializer()) - omega = omega_initializer.generate(self.hparams["input_dim"], - self.hparams["latent_dim"]) + omega_initializer = kwargs.get( + "omega_initializer", EyeLinearTransformInitializer() + ) + omega = omega_initializer.generate( + self.hparams["input_dim"], self.hparams["latent_dim"] + ) self.register_parameter("_omega", Parameter(omega)) @property