Implement a prototypical 2-layer version of GMLVQ
parent d4bf6dbbe9
commit 824dfced92
@@ -5,9 +5,10 @@ from prototorch.core.competitions import wtac
 from prototorch.core.distances import (
     lomega_distance,
     omega_distance,
+    ML_omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import EyeLinearTransformInitializer
+from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -229,6 +230,51 @@ class SiameseGMLVQ(SiameseGLVQ):
         return lam.detach().cpu()
 
 
+class GMLMLVQ(GLVQ):
+    """Generalized Multi-Layer Matrix Learning Vector Quantization.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function. This makes it easier to implement a localized variant.
+    """
+
+    # Parameters
+    _omega: torch.Tensor
+
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
+        omega_initializer = kwargs.get("omega_initializer")
+        masks = kwargs.get("masks")
+        omega_0 = LLTI(masks[0]).generate(1, 1)
+        omega_1 = LLTI(masks[1]).generate(1, 1)
+        self.register_parameter("_omega_0", Parameter(omega_0))
+        self.register_parameter("_omega_1", Parameter(omega_1))
+        self.mask_0 = masks[0]
+        self.mask_1 = masks[1]
+
+    @property
+    def omega_matrix(self):
+        return self._omega.detach().cpu()
+
+    @property
+    def lambda_matrix(self):
+        omega = self._omega.detach()  # (input_dim, latent_dim)
+        lam = omega @ omega.T
+        return lam.detach().cpu()
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omega_0,
+                                        self._omega_1, self.mask_0,
+                                        self.mask_1)
+        return distances
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
+
+
 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.
 
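For orientation, below is a minimal sketch of what a masked, two-layer omega distance could look like. It is an assumption inferred only from the arguments that compute_distances hands to the distance layer (two omega matrices plus two masks); the actual ML_omega_distance added to prototorch.core.distances is not part of this diff and may differ. The function name and the elementwise-mask semantics are hypothetical.

import torch


def two_layer_omega_distance_sketch(x, protos, omega_0, omega_1, mask_0, mask_1):
    """Hypothetical stand-in for ML_omega_distance (name and signature assumed).

    x:       (n_samples, input_dim) data batch
    protos:  (n_protos, input_dim) prototypes
    omega_0: (input_dim, hidden_dim) first projection, gated by mask_0
    omega_1: (hidden_dim, latent_dim) second projection, gated by mask_1
    """
    # Gate each layer elementwise so only the unmasked entries contribute.
    w0 = omega_0 * mask_0
    w1 = omega_1 * mask_1
    # Compose the two layers into a single linear map Omega = W0 @ W1.
    omega = w0 @ w1  # (input_dim, latent_dim)
    # Project data and prototypes into the latent space.
    x_lat = x @ omega        # (n_samples, latent_dim)
    p_lat = protos @ omega   # (n_protos, latent_dim)
    # Pairwise squared Euclidean distances in the latent space.
    diff = x_lat.unsqueeze(1) - p_lat.unsqueeze(0)  # (n_samples, n_protos, latent_dim)
    return (diff**2).sum(-1)  # (n_samples, n_protos)

Composing the two masked layers into a single Omega keeps the measure a plain squared Euclidean distance in latent space, which is why the class can be implemented as a regular GLVQ network that only swaps the distance function, as the docstring above states.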
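A hedged usage sketch, for completeness. The distribution/lr hyperparameters, the prototypes_initializer keyword, pt.datasets.Iris, and pt.initializers.SMCI follow the existing GLVQ examples shipped with ProtoTorch; that GMLMLVQ is importable from prototorch.models and requires nothing beyond the new masks keyword is an assumption, and the mask shapes are purely illustrative.

import torch

import prototorch as pt
from prototorch.models import GMLMLVQ  # assumes the class is re-exported like its siblings

# Iris restricted to two features, as in the stock GLVQ examples.
train_ds = pt.datasets.Iris(dims=[0, 2])

# One mask per omega layer (illustrative shapes: 2 -> 4 -> 2). The constructor
# above turns each mask into a literal omega parameter via LLTI(mask).generate(1, 1)
# and also stores the mask itself for the distance layer.
masks = [torch.ones(2, 4), torch.ones(4, 2)]

model = GMLMLVQ(
    hparams=dict(distribution={"num_classes": 3, "per_class": 1}, lr=0.01),
    prototypes_initializer=pt.initializers.SMCI(train_ds),
    masks=masks,
)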