Add siamese glvq

This commit is contained in:
Jensun Ravichandran
2021-04-27 14:35:17 +02:00
parent 8d57f69c9e
commit 1fb197077c
3 changed files with 170 additions and 12 deletions

View File

@@ -68,8 +68,8 @@ class GLVQ(pl.LightningModule):
# self.log("train_acc_epoch", self.train_acc.compute())
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
# model.eval() # ?!
d = self(x)
plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels)
@@ -77,8 +77,52 @@ class GLVQ(pl.LightningModule):
class ImageGLVQ(GLVQ):
"""GLVQ model that constrains the prototypes to the range [0, 1] by
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
GLVQ model that applies an arbitrary transformation on the inputs and the
prototypes before computing the distances between them. The weights in the
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
self.backbone_dependent = backbone_module(
**backbone_params).requires_grad_(False)
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
def forward(self, x):
self.sync_backbones()
protos = self.proto_layer.prototypes
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
dis = euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
# model.eval() # ?!
with torch.no_grad():
protos = self.proto_layer.prototypes
latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos)
plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels)
return y_pred.numpy()