Refactor code
This commit is contained in:
@@ -37,33 +37,3 @@ class PrototypeImageModel(pl.LightningModule):
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
||||
|
||||
|
||||
class SiamesePrototypeModel(pl.LightningModule):
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
if list(self.backbone.parameters()):
|
||||
# only add an optimizer is the backbone has trainable parameters
|
||||
# otherwise, the next line fails
|
||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||
lr=self.hparams.bb_lr)
|
||||
return proto_opt, bb_opt
|
||||
else:
|
||||
return proto_opt
|
||||
|
||||
def predict_latent(self, x, map_protos=True):
|
||||
"""Predict `x` assuming it is already embedded in the latent space.
|
||||
|
||||
Only the prototypes are embedded in the latent space using the
|
||||
backbone.
|
||||
|
||||
"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
protos, plabels = self.proto_layer()
|
||||
if map_protos:
|
||||
protos = self.backbone(protos)
|
||||
d = self.distance_fn(x, protos)
|
||||
y_pred = wtac(d, plabels)
|
||||
return y_pred
|
||||
|
Reference in New Issue
Block a user