Refactor code
This commit is contained in:
parent
0611f81aba
commit
a5e086ce0d
@ -37,33 +37,3 @@ class PrototypeImageModel(pl.LightningModule):
|
|||||||
if return_channels_last:
|
if return_channels_last:
|
||||||
grid = grid.permute((1, 2, 0))
|
grid = grid.permute((1, 2, 0))
|
||||||
return grid.cpu()
|
return grid.cpu()
|
||||||
|
|
||||||
|
|
||||||
class SiamesePrototypeModel(pl.LightningModule):
|
|
||||||
def configure_optimizers(self):
|
|
||||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
|
||||||
lr=self.hparams.proto_lr)
|
|
||||||
if list(self.backbone.parameters()):
|
|
||||||
# only add an optimizer is the backbone has trainable parameters
|
|
||||||
# otherwise, the next line fails
|
|
||||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
|
||||||
lr=self.hparams.bb_lr)
|
|
||||||
return proto_opt, bb_opt
|
|
||||||
else:
|
|
||||||
return proto_opt
|
|
||||||
|
|
||||||
def predict_latent(self, x, map_protos=True):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space.
|
|
||||||
|
|
||||||
Only the prototypes are embedded in the latent space using the
|
|
||||||
backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
if map_protos:
|
|
||||||
protos = self.backbone(protos)
|
|
||||||
d = self.distance_fn(x, protos)
|
|
||||||
y_pred = wtac(d, plabels)
|
|
||||||
return y_pred
|
|
||||||
|
@ -5,8 +5,7 @@ from prototorch.components.components import Components
|
|||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.functions.similarities import cosine_similarity
|
from prototorch.functions.similarities import cosine_similarity
|
||||||
|
|
||||||
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||||
SiamesePrototypeModel)
|
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,8 +9,7 @@ from prototorch.functions.helper import get_flat
|
|||||||
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
|
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
|
||||||
lvq1_loss, lvq21_loss)
|
lvq1_loss, lvq21_loss)
|
||||||
|
|
||||||
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||||
SiamesePrototypeModel)
|
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
@ -130,7 +129,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
return f"{super_repr}"
|
return f"{super_repr}"
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
||||||
|
|
||||||
GLVQ model that applies an arbitrary transformation on the inputs and the
|
GLVQ model that applies an arbitrary transformation on the inputs and the
|
||||||
@ -148,6 +147,18 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
|
|||||||
self.both_path_gradients = both_path_gradients
|
self.both_path_gradients = both_path_gradients
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
|
lr=self.hparams.proto_lr)
|
||||||
|
if list(self.backbone.parameters()):
|
||||||
|
# only add an optimizer is the backbone has trainable parameters
|
||||||
|
# otherwise, the next line fails
|
||||||
|
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||||
|
lr=self.hparams.bb_lr)
|
||||||
|
return proto_opt, bb_opt
|
||||||
|
else:
|
||||||
|
return proto_opt
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
@ -157,6 +168,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
|
|||||||
distances = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_fn(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
def predict_latent(self, x, map_protos=True):
|
||||||
|
"""Predict `x` assuming it is already embedded in the latent space.
|
||||||
|
|
||||||
|
Only the prototypes are embedded in the latent space using the
|
||||||
|
backbone.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
protos, plabels = self.proto_layer()
|
||||||
|
if map_protos:
|
||||||
|
protos = self.backbone(protos)
|
||||||
|
d = self.distance_fn(x, protos)
|
||||||
|
y_pred = wtac(d, plabels)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class GRLVQ(SiameseGLVQ):
|
class GRLVQ(SiameseGLVQ):
|
||||||
"""Generalized Relevance Learning Vector Quantization."""
|
"""Generalized Relevance Learning Vector Quantization."""
|
||||||
|
Loading…
Reference in New Issue
Block a user