Refactor code
commit a5e086ce0d (parent 0611f81aba)
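In short: the SiamesePrototypeModel mixin is deleted from the abstract module, its configure_optimizers and predict_latent methods move into SiameseGLVQ (which now inherits from GLVQ alone), and the imports of the two consuming modules are updated accordingly.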
@@ -37,33 +37,3 @@ class PrototypeImageModel(pl.LightningModule):
         if return_channels_last:
             grid = grid.permute((1, 2, 0))
         return grid.cpu()
-
-
-class SiamesePrototypeModel(pl.LightningModule):
-    def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams.proto_lr)
-        if list(self.backbone.parameters()):
-            # only add an optimizer if the backbone has trainable parameters
-            # otherwise, the next line fails
-            bb_opt = self.optimizer(self.backbone.parameters(),
-                                    lr=self.hparams.bb_lr)
-            return proto_opt, bb_opt
-        else:
-            return proto_opt
-
-    def predict_latent(self, x, map_protos=True):
-        """Predict `x` assuming it is already embedded in the latent space.
-
-        Only the prototypes are embedded in the latent space using the
-        backbone.
-
-        """
-        self.eval()
-        with torch.no_grad():
-            protos, plabels = self.proto_layer()
-            if map_protos:
-                protos = self.backbone(protos)
-            d = self.distance_fn(x, protos)
-            y_pred = wtac(d, plabels)
-        return y_pred
@@ -5,8 +5,7 @@ from prototorch.components.components import Components
 from prototorch.functions.distances import euclidean_distance
 from prototorch.functions.similarities import cosine_similarity
 
-from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
-                       SiamesePrototypeModel)
+from .abstract import AbstractPrototypeModel, PrototypeImageModel
 from .glvq import SiameseGLVQ
 
 
@@ -9,8 +9,7 @@ from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
                                          lvq1_loss, lvq21_loss)
 
-from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
-                       SiamesePrototypeModel)
+from .abstract import AbstractPrototypeModel, PrototypeImageModel
 
 
 class GLVQ(AbstractPrototypeModel):
@@ -130,7 +129,7 @@ class GLVQ(AbstractPrototypeModel):
         return f"{super_repr}"
 
 
-class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
+class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
 
     GLVQ model that applies an arbitrary transformation on the inputs and the
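Why dropping the mixin is safe: with the old bases, Python's method resolution order consulted SiamesePrototypeModel before GLVQ, so its configure_optimizers and predict_latent won the lookup; defining the same methods directly on SiameseGLVQ, as the hunks below do, gives them the same precedence. A toy illustration of that lookup rule (class names here are stand-ins, not from the codebase):

    class Mixin:
        def configure_optimizers(self):
            return "mixin"

    class Base:
        def configure_optimizers(self):
            return "base"

    class WithMixin(Mixin, Base):   # old style: the mixin wins via the MRO
        pass

    class Inline(Base):             # new style: method defined on the class itself
        def configure_optimizers(self):
            return "mixin"

    assert WithMixin().configure_optimizers() == Inline().configure_optimizers()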
@@ -148,6 +147,18 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
         self.both_path_gradients = both_path_gradients
         self.distance_fn = kwargs.get("distance_fn", sed)
 
+    def configure_optimizers(self):
+        proto_opt = self.optimizer(self.proto_layer.parameters(),
+                                   lr=self.hparams.proto_lr)
+        if list(self.backbone.parameters()):
+            # only add an optimizer if the backbone has trainable parameters
+            # otherwise, the next line fails
+            bb_opt = self.optimizer(self.backbone.parameters(),
+                                    lr=self.hparams.bb_lr)
+            return proto_opt, bb_opt
+        else:
+            return proto_opt
+
     def _forward(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
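For context on the list(self.backbone.parameters()) guard above: torch.optim optimizers raise a ValueError when handed an empty parameter list, so a parameter-free backbone (e.g. nn.Identity) must fall back to returning the prototype optimizer alone. A minimal sketch of that failure mode (the backbone here is hypothetical, not from the commit):

    import torch
    from torch import nn

    backbone = nn.Identity()                # has no trainable parameters
    assert not list(backbone.parameters())  # the guard above would be False

    try:
        torch.optim.Adam(backbone.parameters(), lr=1e-3)
    except ValueError as err:
        print(err)                          # "optimizer got an empty parameter list"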
@@ -157,6 +168,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
         distances = self.distance_fn(latent_x, latent_protos)
         return distances
 
+    def predict_latent(self, x, map_protos=True):
+        """Predict `x` assuming it is already embedded in the latent space.
+
+        Only the prototypes are embedded in the latent space using the
+        backbone.
+
+        """
+        self.eval()
+        with torch.no_grad():
+            protos, plabels = self.proto_layer()
+            if map_protos:
+                protos = self.backbone(protos)
+            d = self.distance_fn(x, protos)
+            y_pred = wtac(d, plabels)
+        return y_pred
+
 
 class GRLVQ(SiameseGLVQ):
     """Generalized Relevance Learning Vector Quantization."""
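For readers unfamiliar with wtac: it performs winner-takes-all classification, assigning each sample the label of its nearest prototype. A self-contained sketch of that logic, assuming distances has shape (n_samples, n_prototypes); this is an illustration, not prototorch's actual implementation:

    import torch

    def wtac_sketch(distances, plabels):
        """Return the label of the closest prototype for each sample."""
        winners = distances.argmin(dim=1)   # index of the nearest prototype
        return plabels[winners]

    d = torch.tensor([[0.2, 1.5, 0.9],      # 2 samples x 3 prototypes
                      [2.0, 0.1, 0.7]])
    plabels = torch.tensor([0, 1, 1])
    print(wtac_sketch(d, plabels))          # tensor([0, 1])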