prototorch_models/prototorch/models/abstract.py
Jensun Ravichandran 81346785bd Cleanup models
Siamese architectures no longer accept a `backbone_module`. Instead, they must
be given an already-initialized backbone object. This is so that the
visualization callbacks can use the very same object for visualization
purposes. Also, there is no longer a dependent copy of the backbone; it is
managed simply via `requires_grad` instead.
2021-05-17 17:00:23 +02:00

63 lines
2.1 KiB
Python

import pytorch_lightning as pl
import torch
from prototorch.functions.competitions import wtac
from torch.optim.lr_scheduler import ExponentialLR
class AbstractPrototypeModel(pl.LightningModule):
    """Base LightningModule for prototype-based models.

    Subclasses must provide ``self.proto_layer`` (with a ``components``
    tensor), ``self.optimizer`` (an optimizer class or factory taking
    ``(params, lr=...)``) and ``self.hparams.lr``; the members below read them.
    """

    @property
    def prototypes(self):
        """Return the prototype components detached from the graph, on the CPU."""
        return self.proto_layer.components.detach().cpu()

    @property
    def components(self):
        """Only an alias for the prototypes."""
        return self.prototypes

    def configure_optimizers(self):
        """Create the optimizer and an exponentially decaying LR schedule.

        The learning rate starts at ``self.hparams.lr`` and is multiplied by
        0.99 after every training step.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair in the format
            Lightning accepts from ``configure_optimizers``.
        """
        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
        # `last_epoch=-1` is the default. `verbose` is deprecated in recent
        # PyTorch (passing any value emits a UserWarning) and defaulted to
        # False in older releases, so it is omitted.
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        sch = {
            "scheduler": scheduler,
            # Step the scheduler after each training step, not each epoch.
            "interval": "step",
        }
        return [optimizer], [sch]
class PrototypeImageModel(pl.LightningModule):
    """Mixin that keeps prototype components within [0, 1].

    Presumably intended for image data normalized to [0, 1]. Requires
    ``self.proto_layer.components`` to be provided by the concrete model.
    """

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Clamp the prototype components in place to [0, 1] after each batch.

        ``dataloader_idx`` defaults to 0 so the hook works both with Lightning
        versions that pass it and with those (>= 2.0) that call the hook with
        only ``(outputs, batch, batch_idx)``. ``outputs``, ``batch`` and the
        indices are unused.
        """
        # Modify the underlying storage via `.data` so the clamp is not
        # recorded by autograd.
        self.proto_layer.components.data.clamp_(0.0, 1.0)
class SiamesePrototypeModel(pl.LightningModule):
    """Mixin for prototype models that embed data through a backbone.

    The concrete model must provide ``self.proto_layer``, ``self.backbone``,
    ``self.optimizer``, ``self.distance_fn`` and the hyperparameters
    ``self.hparams.proto_lr`` / ``self.hparams.bb_lr``.
    """

    def configure_optimizers(self):
        """Build a prototype optimizer and, when possible, a backbone optimizer.

        Returns:
            ``(proto_opt, bb_opt)`` if the backbone has at least one
            parameter, otherwise only ``proto_opt``.
        """
        proto_opt = self.optimizer(self.proto_layer.parameters(),
                                   lr=self.hparams.proto_lr)
        backbone_params = list(self.backbone.parameters())
        # Only add an optimizer if the backbone has parameters: constructing
        # an optimizer from an empty parameter list fails.
        # NOTE(review): frozen parameters (requires_grad=False) still count.
        if not backbone_params:
            return proto_opt
        bb_opt = self.optimizer(backbone_params, lr=self.hparams.bb_lr)
        return proto_opt, bb_opt

    def predict_latent(self, x, map_protos=True):
        """Predict labels for `x`, assumed to be already embedded in latent space.

        Only the prototypes are passed through the backbone (when
        ``map_protos`` is True); ``x`` is compared against them as-is.
        Prediction is the winner-takes-all-competition label over the
        distances computed by ``self.distance_fn``.
        """
        # NOTE(review): this does not switch the model to eval mode; callers
        # may need to if the backbone behaves differently in training mode.
        with torch.no_grad():
            raw_protos, proto_labels = self.proto_layer()
            latent_protos = (self.backbone(raw_protos)
                             if map_protos else raw_protos)
            distances = self.distance_fn(x, latent_protos)
            return wtac(distances, proto_labels)