Add siamese glvq
@@ -68,8 +68,8 @@ class GLVQ(pl.LightningModule):
         # self.log("train_acc_epoch", self.train_acc.compute())
 
     def predict(self, x):
+        # model.eval()  # ?!
         with torch.no_grad():
-            # model.eval()  # ?!
             d = self(x)
             plabels = self.proto_layer.prototype_labels
             y_pred = wtac(d, plabels)
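For reference, `wtac` is the winner-takes-all classification rule used by `predict`: each sample receives the label of its closest prototype. A minimal sketch of that rule (illustrative only; the actual helper is imported from the prototype library, and its exact implementation is not part of this commit):

import torch

def wtac(distances, prototype_labels):
    # distances: (num_samples, num_prototypes) distance matrix
    # prototype_labels: (num_prototypes,) class label of each prototype
    # Winner-takes-all: return the label of the nearest prototype per sample.
    winners = distances.argmin(dim=1)
    return prototype_labels[winners]
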
@@ -77,8 +77,52 @@ class GLVQ(pl.LightningModule):
 
 
 class ImageGLVQ(GLVQ):
-    """GLVQ model that constrains the prototypes to the range [0, 1] by
+    """GLVQ for training on image data.
+
+    GLVQ model that constrains the prototypes to the range [0, 1] by
     clamping after updates.
     """
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
         self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
+
+
+class SiameseGLVQ(GLVQ):
+    """GLVQ in a Siamese setting.
+
+    GLVQ model that applies an arbitrary transformation on the inputs and the
+    prototypes before computing the distances between them. The weights in the
+    transformation pipeline are only learned from the inputs.
+    """
+    def __init__(self,
+                 hparams,
+                 backbone_module=torch.nn.Identity,
+                 backbone_params={},
+                 **kwargs):
+        super().__init__(hparams, **kwargs)
+        self.backbone = backbone_module(**backbone_params)
+        self.backbone_dependent = backbone_module(
+            **backbone_params).requires_grad_(False)
+
+    def sync_backbones(self):
+        master_state = self.backbone.state_dict()
+        self.backbone_dependent.load_state_dict(master_state, strict=True)
+
+    def forward(self, x):
+        self.sync_backbones()
+        protos = self.proto_layer.prototypes
+
+        latent_x = self.backbone(x)
+        latent_protos = self.backbone_dependent(protos)
+
+        dis = euclidean_distance(latent_x, latent_protos)
+        return dis
+
+    def predict_latent(self, x):
+        # model.eval()  # ?!
+        with torch.no_grad():
+            protos = self.proto_layer.prototypes
+            latent_protos = self.backbone_dependent(protos)
+            d = euclidean_distance(x, latent_protos)
+            plabels = self.proto_layer.prototype_labels
+            y_pred = wtac(d, plabels)
+        return y_pred.numpy()
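The key idea in SiameseGLVQ is that the prototypes are embedded with a frozen copy of the backbone, refreshed from the trainable backbone's state dict on every forward pass. Both branches therefore use the same transformation when the distances are computed, but gradients for the backbone weights flow only through the input branch, which is what the docstring means by the weights being learned only from the inputs. A self-contained sketch of that mechanism in plain PyTorch (independent of the GLVQ classes above; the dimensions and `torch.cdist` stand-in for `euclidean_distance` are illustrative):

import torch

# Trainable backbone and a frozen copy for the prototype branch.
backbone = torch.nn.Linear(4, 2)
backbone_dependent = torch.nn.Linear(4, 2).requires_grad_(False)

def sync_backbones():
    # Mirror the trainable weights into the frozen copy,
    # as SiameseGLVQ.sync_backbones() does.
    backbone_dependent.load_state_dict(backbone.state_dict(), strict=True)

x = torch.randn(8, 4)                            # inputs
protos = torch.randn(3, 4, requires_grad=True)   # stand-in prototypes

sync_backbones()
latent_x = backbone(x)                      # gradients flow into backbone
latent_protos = backbone_dependent(protos)  # no gradients into the copy

# Pairwise Euclidean distances between inputs and prototypes in latent
# space (torch.cdist used here in place of the library's helper).
dis = torch.cdist(latent_x, latent_protos)

dis.sum().backward()
print(backbone.weight.grad is not None)        # True: learned from inputs
print(backbone_dependent.weight.grad is None)  # True: frozen copy
print(protos.grad is not None)                 # True: prototypes still learn

Calling the real model works the same way: `SiameseGLVQ(hparams, backbone_module=SomeModule, backbone_params={...})` builds the two backbone copies, and `forward` syncs them before computing the latent-space distance matrix. The exact keys expected in `hparams` come from the `GLVQ` base class and are not shown in this diff.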