feat: ImageGTLVQ and SiameseGTLVQ with examples
This commit is contained in:
committed by
Jensun Ravichandran
parent
d3bb430104
commit
a9edf06507
@@ -13,8 +13,10 @@ from .glvq import (
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
ImageGTLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
SiameseGTLVQ,
|
||||
)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
|
@@ -284,22 +284,28 @@ class LGMLVQ(GMLVQ):
|
||||
|
||||
|
||||
class GTLVQ(LGMLVQ):
|
||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||
"""Localized and Generalized Tangent Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
omega_initializer = kwargs.get("omega_initializer")
|
||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
|
||||
if omega_initializer is not None:
|
||||
subspace = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
omega = torch.repeat_interleave(subspace.unsqueeze(0),
|
||||
self.num_prototypes,
|
||||
dim=0)
|
||||
else:
|
||||
omega = torch.rand(
|
||||
self.num_prototypes,
|
||||
self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Re-register `_omega` to override the one from the super class.
|
||||
omega = torch.rand(
|
||||
self.num_prototypes,
|
||||
self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
@@ -307,6 +313,14 @@ class GTLVQ(LGMLVQ):
|
||||
self._omega.copy_(orthogonalization(self._omega))
|
||||
|
||||
|
||||
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
|
||||
"""Generalized Tangent Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class GLVQ1(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
@@ -339,3 +353,17 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
||||
after updates.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
|
||||
"""GTLVQ for training on image data.
|
||||
|
||||
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
after updates.
|
||||
|
||||
"""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
with torch.no_grad():
|
||||
self._omega.copy_(orthogonalization(self._omega))
|
||||
|
Reference in New Issue
Block a user