diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index 4328b10..1ba9f06 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -5,9 +5,10 @@ from prototorch.core.competitions import wtac from prototorch.core.distances import ( lomega_distance, omega_distance, + ML_omega_distance, squared_euclidean_distance, ) -from prototorch.core.initializers import EyeLinearTransformInitializer +from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI) from prototorch.core.losses import ( GLVQLoss, lvq1_loss, @@ -229,6 +230,51 @@ class SiameseGMLVQ(SiameseGLVQ): return lam.detach().cpu() +class GMLMLVQ(GLVQ): + """Generalized Multi-Layer Matrix Learning Vector Quantization. + + Implemented as a regular GLVQ network that simply uses a different distance + function. This makes it easier to implement a localized variant. + """ + + # Parameters + _omega: torch.Tensor + + def __init__(self, hparams, **kwargs): + distance_fn = kwargs.pop("distance_fn", ML_omega_distance) + super().__init__(hparams, distance_fn=distance_fn, **kwargs) + + # Additional parameters + omega_initializer = kwargs.get("omega_initializer") + masks = kwargs.get("masks") + omega_0 = LLTI(masks[0]).generate(1, 1) + omega_1 = LLTI(masks[1]).generate(1, 1) + self.register_parameter("_omega_0", Parameter(omega_0)) + self.register_parameter("_omega_1", Parameter(omega_1)) + self.mask_0 = masks[0] + self.mask_1 = masks[1] + + @property + def omega_matrix(self): + return self._omega.detach().cpu() + + @property + def lambda_matrix(self): + omega = self._omega.detach() # (input_dim, latent_dim) + lam = omega @ omega.T + return lam.detach().cpu() + + def compute_distances(self, x): + protos, _ = self.proto_layer() + distances = self.distance_layer(x, protos, self._omega_0, + self._omega_1, self.mask_0, + self.mask_1) + return distances + + def extra_repr(self): + return f"(omega): (shape: {tuple(self._omega.shape)})" + + class GMLVQ(GLVQ): """Generalized Matrix Learning Vector Quantization.