From a5e086ce0d488916727f1dc98c6c6c91925578b9 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 21 May 2021 13:33:57 +0200 Subject: [PATCH] Refactor code --- prototorch/models/abstract.py | 30 ------------------------------ prototorch/models/cbc.py | 3 +-- prototorch/models/glvq.py | 33 ++++++++++++++++++++++++++++++--- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index ca21908..a80189e 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -37,33 +37,3 @@ class PrototypeImageModel(pl.LightningModule): if return_channels_last: grid = grid.permute((1, 2, 0)) return grid.cpu() - - -class SiamesePrototypeModel(pl.LightningModule): - def configure_optimizers(self): - proto_opt = self.optimizer(self.proto_layer.parameters(), - lr=self.hparams.proto_lr) - if list(self.backbone.parameters()): - # only add an optimizer is the backbone has trainable parameters - # otherwise, the next line fails - bb_opt = self.optimizer(self.backbone.parameters(), - lr=self.hparams.bb_lr) - return proto_opt, bb_opt - else: - return proto_opt - - def predict_latent(self, x, map_protos=True): - """Predict `x` assuming it is already embedded in the latent space. - - Only the prototypes are embedded in the latent space using the - backbone. - - """ - self.eval() - with torch.no_grad(): - protos, plabels = self.proto_layer() - if map_protos: - protos = self.backbone(protos) - d = self.distance_fn(x, protos) - y_pred = wtac(d, plabels) - return y_pred diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 64b48c1..3862eba 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -5,8 +5,7 @@ from prototorch.components.components import Components from prototorch.functions.distances import euclidean_distance from prototorch.functions.similarities import cosine_similarity -from .abstract import (AbstractPrototypeModel, PrototypeImageModel, - SiamesePrototypeModel) +from .abstract import AbstractPrototypeModel, PrototypeImageModel from .glvq import SiameseGLVQ diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index e34d023..781b809 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -9,8 +9,7 @@ from prototorch.functions.helper import get_flat from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss, lvq1_loss, lvq21_loss) -from .abstract import (AbstractPrototypeModel, PrototypeImageModel, - SiamesePrototypeModel) +from .abstract import AbstractPrototypeModel, PrototypeImageModel class GLVQ(AbstractPrototypeModel): @@ -130,7 +129,7 @@ class GLVQ(AbstractPrototypeModel): return f"{super_repr}" -class SiameseGLVQ(SiamesePrototypeModel, GLVQ): +class SiameseGLVQ(GLVQ): """GLVQ in a Siamese setting. GLVQ model that applies an arbitrary transformation on the inputs and the @@ -148,6 +147,18 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ): self.both_path_gradients = both_path_gradients self.distance_fn = kwargs.get("distance_fn", sed) + def configure_optimizers(self): + proto_opt = self.optimizer(self.proto_layer.parameters(), + lr=self.hparams.proto_lr) + if list(self.backbone.parameters()): + # only add an optimizer is the backbone has trainable parameters + # otherwise, the next line fails + bb_opt = self.optimizer(self.backbone.parameters(), + lr=self.hparams.bb_lr) + return proto_opt, bb_opt + else: + return proto_opt + def _forward(self, x): protos, _ = self.proto_layer() latent_x = self.backbone(x) @@ -157,6 +168,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ): distances = self.distance_fn(latent_x, latent_protos) return distances + def predict_latent(self, x, map_protos=True): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + self.eval() + with torch.no_grad(): + protos, plabels = self.proto_layer() + if map_protos: + protos = self.backbone(protos) + d = self.distance_fn(x, protos) + y_pred = wtac(d, plabels) + return y_pred + class GRLVQ(SiameseGLVQ): """Generalized Relevance Learning Vector Quantization."""