Refactor code

This commit is contained in:
Jensun Ravichandran 2021-05-21 13:33:57 +02:00
parent 0611f81aba
commit a5e086ce0d
3 changed files with 31 additions and 35 deletions

View File

@ -37,33 +37,3 @@ class PrototypeImageModel(pl.LightningModule):
if return_channels_last: if return_channels_last:
grid = grid.permute((1, 2, 0)) grid = grid.permute((1, 2, 0))
return grid.cpu() return grid.cpu()
class SiamesePrototypeModel(pl.LightningModule):
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
if list(self.backbone.parameters()):
# only add an optimizer is the backbone has trainable parameters
# otherwise, the next line fails
bb_opt = self.optimizer(self.backbone.parameters(),
lr=self.hparams.bb_lr)
return proto_opt, bb_opt
else:
return proto_opt
def predict_latent(self, x, map_protos=True):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
self.eval()
with torch.no_grad():
protos, plabels = self.proto_layer()
if map_protos:
protos = self.backbone(protos)
d = self.distance_fn(x, protos)
y_pred = wtac(d, plabels)
return y_pred

View File

@ -5,8 +5,7 @@ from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity from prototorch.functions.similarities import cosine_similarity
from .abstract import (AbstractPrototypeModel, PrototypeImageModel, from .abstract import AbstractPrototypeModel, PrototypeImageModel
SiamesePrototypeModel)
from .glvq import SiameseGLVQ from .glvq import SiameseGLVQ

View File

@ -9,8 +9,7 @@ from prototorch.functions.helper import get_flat
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss, from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
lvq1_loss, lvq21_loss) lvq1_loss, lvq21_loss)
from .abstract import (AbstractPrototypeModel, PrototypeImageModel, from .abstract import AbstractPrototypeModel, PrototypeImageModel
SiamesePrototypeModel)
class GLVQ(AbstractPrototypeModel): class GLVQ(AbstractPrototypeModel):
@ -130,7 +129,7 @@ class GLVQ(AbstractPrototypeModel):
return f"{super_repr}" return f"{super_repr}"
class SiameseGLVQ(SiamesePrototypeModel, GLVQ): class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting. """GLVQ in a Siamese setting.
GLVQ model that applies an arbitrary transformation on the inputs and the GLVQ model that applies an arbitrary transformation on the inputs and the
@ -148,6 +147,18 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
self.both_path_gradients = both_path_gradients self.both_path_gradients = both_path_gradients
self.distance_fn = kwargs.get("distance_fn", sed) self.distance_fn = kwargs.get("distance_fn", sed)
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
if list(self.backbone.parameters()):
# only add an optimizer is the backbone has trainable parameters
# otherwise, the next line fails
bb_opt = self.optimizer(self.backbone.parameters(),
lr=self.hparams.bb_lr)
return proto_opt, bb_opt
else:
return proto_opt
def _forward(self, x): def _forward(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
@ -157,6 +168,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
distances = self.distance_fn(latent_x, latent_protos) distances = self.distance_fn(latent_x, latent_protos)
return distances return distances
def predict_latent(self, x, map_protos=True):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
self.eval()
with torch.no_grad():
protos, plabels = self.proto_layer()
if map_protos:
protos = self.backbone(protos)
d = self.distance_fn(x, protos)
y_pred = wtac(d, plabels)
return y_pred
class GRLVQ(SiameseGLVQ): class GRLVQ(SiameseGLVQ):
"""Generalized Relevance Learning Vector Quantization.""" """Generalized Relevance Learning Vector Quantization."""