import pytorch_lightning as pl
import torch
from prototorch.functions.competitions import wtac
from torch.optim.lr_scheduler import ExponentialLR


class AbstractPrototypeModel(pl.LightningModule):
    @property
    def prototypes(self):
        # Return a detached copy on the CPU, safe for logging and plotting.
        return self.proto_layer.components.detach().cpu()

    @property
    def components(self):
        """Only an alias for the prototypes."""
        return self.prototypes

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
        scheduler = ExponentialLR(optimizer,
                                  gamma=0.99,
                                  last_epoch=-1,
                                  verbose=False)
        sch = {
            "scheduler": scheduler,
            "interval": "step",
        }  # called after each training step
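        # With `"interval": "step"`, Lightning steps the scheduler after every
        # optimizer step rather than once per epoch, so the learning rate
        # decays by a factor of 0.99 ** steps_per_epoch over a full epoch.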
        return [optimizer], [sch]


class PrototypeImageModel(pl.LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # Constrain the prototypes to the range of valid pixel values.
        self.proto_layer.components.data.clamp_(0.0, 1.0)


class SiamesePrototypeModel(pl.LightningModule):
    def configure_optimizers(self):
        proto_opt = self.optimizer(self.proto_layer.parameters(),
                                   lr=self.hparams.proto_lr)
        if list(self.backbone.parameters()):
            # Only add an optimizer if the backbone has trainable parameters;
            # otherwise, the next line fails.
            bb_opt = self.optimizer(self.backbone.parameters(),
                                    lr=self.hparams.bb_lr)
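            # Returning two optimizers makes Lightning (1.x) pass an extra
            # `optimizer_idx` argument to `training_step`, which subclasses
            # then need to accept.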
            return proto_opt, bb_opt
        else:
            return proto_opt

    def predict_latent(self, x, map_protos=True):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone.

        """
        # model.eval()  # ?!
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            if map_protos:
                protos = self.backbone(protos)
            d = self.distance_fn(x, protos)
            # Winner-takes-all: assign the label of the closest prototype.
            y_pred = wtac(d, plabels)
        return y_pred
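

# A minimal usage sketch (assumption: a concrete subclass, here a hypothetical
# `MySiameseModel(SiamesePrototypeModel)`, supplies `proto_layer`, `backbone`,
# `distance_fn`, `optimizer`, and the hparams referenced above):
#
#     model = MySiameseModel(hparams)
#     z = model.backbone(x)             # embed a batch into the latent space
#     y_pred = model.predict_latent(z)  # classify the already-embedded batch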