Implement a prototypical 2-layer version of GMLVQ

This commit is contained in:
julius 2023-11-03 14:59:00 +01:00
parent d4bf6dbbe9
commit 824dfced92
Signed by untrusted user who does not match committer: julius
GPG Key ID: 8AA3791362A8084A

View File

@ -5,9 +5,10 @@ from prototorch.core.competitions import wtac
from prototorch.core.distances import ( from prototorch.core.distances import (
lomega_distance, lomega_distance,
omega_distance, omega_distance,
ML_omega_distance,
squared_euclidean_distance, squared_euclidean_distance,
) )
from prototorch.core.initializers import EyeLinearTransformInitializer from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
from prototorch.core.losses import ( from prototorch.core.losses import (
GLVQLoss, GLVQLoss,
lvq1_loss, lvq1_loss,
@ -229,6 +230,51 @@ class SiameseGMLVQ(SiameseGLVQ):
return lam.detach().cpu() return lam.detach().cpu()
class GMLMLVQ(GLVQ):
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
Implemented as a regular GLVQ network that simply uses a different distance
function. This makes it easier to implement a localized variant.
"""
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get("omega_initializer")
masks = kwargs.get("masks")
omega_0 = LLTI(masks[0]).generate(1, 1)
omega_1 = LLTI(masks[1]).generate(1, 1)
self.register_parameter("_omega_0", Parameter(omega_0))
self.register_parameter("_omega_1", Parameter(omega_1))
self.mask_0 = masks[0]
self.mask_1 = masks[1]
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega_0,
self._omega_1, self.mask_0,
self.mask_1)
return distances
def extra_repr(self):
return f"(omega): (shape: {tuple(self._omega.shape)})"
class GMLVQ(GLVQ): class GMLVQ(GLVQ):
"""Generalized Matrix Learning Vector Quantization. """Generalized Matrix Learning Vector Quantization.