Cleanup models
Siamese architectures no longer accept a `backbone_module`. They have to be initialized with a pre-initialized backbone object instead, so that the visualization callbacks can reuse the very same object. Also, there is no longer a dependent copy of the backbone: freezing is now handled simply via `requires_grad`.
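A minimal sketch of the new usage pattern, for illustration (the model class name and constructor keyword in the commented-out line are assumptions, not part of this commit; only the backbone construction and the `requires_grad` calls are standard PyTorch):

    import torch

    # The backbone is now built by the user and passed in as an object, so a
    # visualization callback can be handed this exact same instance.
    backbone = torch.nn.Sequential(
        torch.nn.Linear(100, 2),
        torch.nn.Sigmoid(),
    )

    # model = SomeSiameseModel(hparams, backbone=backbone)  # hypothetical name/signature

    # Freezing no longer relies on a separate copy of the backbone; it is
    # controlled directly through `requires_grad`:
    backbone.requires_grad_(False)  # freeze
    backbone.requires_grad_(True)   # unfreeze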
@@ -1,5 +1,6 @@
import pytorch_lightning as pl
import torch
from prototorch.functions.competitions import wtac
from torch.optim.lr_scheduler import ExponentialLR


@@ -29,3 +30,33 @@ class AbstractPrototypeModel(pl.LightningModule):
class PrototypeImageModel(pl.LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        self.proto_layer.components.data.clamp_(0.0, 1.0)


class SiamesePrototypeModel(pl.LightningModule):
    def configure_optimizers(self):
        proto_opt = self.optimizer(self.proto_layer.parameters(),
                                   lr=self.hparams.proto_lr)
        if list(self.backbone.parameters()):
            # only add an optimizer if the backbone has trainable parameters
            # otherwise, the next line fails
            bb_opt = self.optimizer(self.backbone.parameters(),
                                    lr=self.hparams.bb_lr)
            return proto_opt, bb_opt
        else:
            return proto_opt

    def predict_latent(self, x, map_protos=True):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone.

        """
        # model.eval() # ?!
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            if map_protos:
                protos = self.backbone(protos)
            d = self.distance_fn(x, protos)
            y_pred = wtac(d, plabels)
        return y_pred
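For context on the guard in `configure_optimizers`: constructing a torch optimizer over an empty parameter list raises a `ValueError`, which is exactly what the `if list(self.backbone.parameters()):` check avoids when the backbone is parameterless. A small, self-contained demonstration (the `Identity` backbone is just a stand-in):

    import torch

    # A backbone without trainable parameters, e.g. when the latent space
    # equals the input space.
    parameterless_backbone = torch.nn.Identity()

    try:
        torch.optim.Adam(parameterless_backbone.parameters(), lr=0.01)
    except ValueError as err:
        print(err)  # -> "optimizer got an empty parameter list"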