Implement a prototypical 2-layer version of GMLVQ
This commit is contained in:
parent
d4bf6dbbe9
commit
824dfced92
@ -5,9 +5,10 @@ from prototorch.core.competitions import wtac
|
||||
from prototorch.core.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
ML_omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.core.initializers import EyeLinearTransformInitializer
|
||||
from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
|
||||
from prototorch.core.losses import (
|
||||
GLVQLoss,
|
||||
lvq1_loss,
|
||||
@ -229,6 +230,51 @@ class SiameseGMLVQ(SiameseGLVQ):
|
||||
return lam.detach().cpu()
|
||||
|
||||
|
||||
class GMLMLVQ(GLVQ):
|
||||
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a regular GLVQ network that simply uses a different distance
|
||||
function. This makes it easier to implement a localized variant.
|
||||
"""
|
||||
|
||||
# Parameters
|
||||
_omega: torch.Tensor
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
omega_initializer = kwargs.get("omega_initializer")
|
||||
masks = kwargs.get("masks")
|
||||
omega_0 = LLTI(masks[0]).generate(1, 1)
|
||||
omega_1 = LLTI(masks[1]).generate(1, 1)
|
||||
self.register_parameter("_omega_0", Parameter(omega_0))
|
||||
self.register_parameter("_omega_1", Parameter(omega_1))
|
||||
self.mask_0 = masks[0]
|
||||
self.mask_1 = masks[1]
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self._omega.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
omega = self._omega.detach() # (input_dim, latent_dim)
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, self._omega_0,
|
||||
self._omega_1, self.mask_0,
|
||||
self.mask_1)
|
||||
return distances
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
||||
|
||||
|
||||
class GMLVQ(GLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user