2021-06-04 20:20:32 +00:00
|
|
|
"""Abstract classes to be inherited by prototorch models."""
|
|
|
|
|
2021-04-29 17:14:33 +00:00
|
|
|
import pytorch_lightning as pl
|
2021-06-04 20:20:32 +00:00
|
|
|
import torch
|
|
|
|
import torchmetrics
|
2021-06-14 18:08:08 +00:00
|
|
|
|
|
|
|
from ..core.competitions import WTAC
|
|
|
|
from ..core.components import Components, LabeledComponents
|
|
|
|
from ..core.distances import euclidean_distance
|
2022-02-02 20:53:03 +00:00
|
|
|
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
|
2021-06-14 18:08:08 +00:00
|
|
|
from ..core.pooling import stratified_min_pooling
|
|
|
|
from ..nn.wrappers import LambdaLayer
|
2021-06-04 20:20:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ProtoTorchBolt(pl.LightningModule):
    """All ProtoTorch models are ProtoTorch Bolts.

    Common base that stores the hyperparameters, builds the optimizer (plus
    an optional per-step learning-rate scheduler) and can rebuild the
    optimizers after the set of trainable parameters has changed.
    """

    def __init__(self, hparams, **kwargs):
        """Store hyperparameters and the optimizer configuration.

        Args:
            hparams: Hyperparameters handed to ``save_hyperparameters``;
                ``lr`` defaults to 0.01 when absent.
            **kwargs: Optional ``optimizer`` (optimizer class, default
                ``torch.optim.Adam``), ``lr_scheduler`` (scheduler class or
                ``None``) and ``lr_scheduler_kwargs`` (extra keyword
                arguments for the scheduler, default empty).
        """
        super().__init__()

        # Hyperparameters
        self.save_hyperparameters(hparams)

        # Default hparams
        self.hparams.setdefault("lr", 0.01)

        # Default config
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
        self.lr_scheduler = kwargs.get("lr_scheduler", None)
        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

    def configure_optimizers(self):
        """Create the optimizer and, if configured, a per-step LR scheduler.

        Returns:
            The optimizer alone when no scheduler is configured, otherwise
            ``([optimizer], [scheduler_config])`` with the scheduler stepped
            after every training step.
        """
        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
        if self.lr_scheduler is not None:
            scheduler = self.lr_scheduler(optimizer,
                                          **self.lr_scheduler_kwargs)
            sch = {
                "scheduler": scheduler,
                "interval": "step",
            }  # called after each training step
            return [optimizer], [sch]
        else:
            return optimizer

    def reconfigure_optimizers(self):
        """Rebuild the optimizers so they see the current parameters.

        Used after prototypes are added or removed. When the module is not
        attached to a trainer yet (``self.trainer is None``) there are no
        optimizers to rebuild; ``configure_optimizers`` will build them from
        ``self.parameters()`` once training starts. So this is a no-op
        instead of failing on ``None.accelerator``.
        """
        if self.trainer is None:
            return
        self.trainer.accelerator.setup_optimizers(self.trainer)

    def __repr__(self):
        """Wrap the default module repr, tab-indented, in ``ProtoTorch Bolt(...)``."""
        surep = super().__repr__()
        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
        wrapped = f"ProtoTorch Bolt(\n{indented})"
        return wrapped
|
|
|
|
|
|
|
|
|
|
|
|
class PrototypeModel(ProtoTorchBolt):
    """Bolt that holds a distance layer and exposes prototype helpers.

    This class reads ``self.proto_layer`` but never assigns it; concrete
    subclasses are responsible for creating that layer.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # The distance function (default: euclidean_distance) is wrapped in a
        # LambdaLayer so it lives on the model as a submodule.
        self.distance_layer = LambdaLayer(
            kwargs.get("distance_fn", euclidean_distance))

    @property
    def num_prototypes(self):
        """Number of components currently held by the prototype layer."""
        return len(self.proto_layer.components)

    @property
    def prototypes(self):
        """Detached CPU copy of the prototype layer's components."""
        components = self.proto_layer.components
        return components.detach().cpu()

    @property
    def components(self):
        """Only an alias for the prototypes."""
        return self.prototypes

    def _sync_after_prototype_change(self):
        # Record the layer's new distribution in the hparams, then rebuild the
        # optimizers so they track the changed parameter set.
        self.hparams.distribution = self.proto_layer.distribution
        self.reconfigure_optimizers()

    def add_prototypes(self, *args, **kwargs):
        """Forward to ``proto_layer.add_components`` and resync the model."""
        self.proto_layer.add_components(*args, **kwargs)
        self._sync_after_prototype_change()

    def remove_prototypes(self, indices):
        """Forward to ``proto_layer.remove_components`` and resync the model."""
        self.proto_layer.remove_components(indices)
        self._sync_after_prototype_change()
|
|
|
|
|
|
|
|
|
|
|
|
class UnsupervisedPrototypeModel(PrototypeModel):
    """Prototype model whose prototypes have no class labels."""

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Layers: the prototype layer is only built when an initializer is
        # supplied through the keyword arguments.
        initializer = kwargs.get("prototypes_initializer", None)
        if initializer is None:
            return
        self.proto_layer = Components(
            self.hparams.num_prototypes,
            initializer=initializer,
        )

    def compute_distances(self, x):
        """Distances from ``x`` to every prototype.

        The prototypes are cast to the tensor type of ``x`` via ``type_as``
        before the distance layer is applied.
        """
        return self.distance_layer(x, self.proto_layer().type_as(x))

    def forward(self, x):
        """The forward pass simply returns the prototype distances."""
        return self.compute_distances(x)
|
|
|
|
|
|
|
|
|
|
|
|
class SupervisedPrototypeModel(PrototypeModel):
    """Prototype model whose prototypes carry class labels."""

    def __init__(self, hparams, skip_proto_layer=False, **kwargs):
        """Build the labeled prototype layer and the competition layer.

        Args:
            hparams: Hyperparameters; ``distribution`` (if present) is passed
                to ``LabeledComponents``. When restoring a model without a
                ``prototypes_initializer``, ``initialized_proto_shape`` must
                be present.
            skip_proto_layer: If ``True``, ``self.proto_layer`` is not created
                here (for subclasses that build a customized one).
            **kwargs: Optional ``prototypes_initializer`` and
                ``labels_initializer`` (default ``LabelsInitializer()``), plus
                everything accepted by the parent classes.
        """
        super().__init__(hparams, **kwargs)

        # Layers
        # Read from ``self.hparams`` (filled by ``save_hyperparameters``)
        # rather than the raw argument, which may be a type without ``.get``
        # (e.g. an ``argparse.Namespace``); this also matches the
        # ``self.hparams`` lookups below.
        distribution = self.hparams.get("distribution", None)
        prototypes_initializer = kwargs.get("prototypes_initializer", None)
        labels_initializer = kwargs.get("labels_initializer",
                                        LabelsInitializer())
        if not skip_proto_layer:
            # when subclasses do not need a customized prototype layer
            if prototypes_initializer is not None:
                # when building a new model
                self.proto_layer = LabeledComponents(
                    distribution=distribution,
                    components_initializer=prototypes_initializer,
                    labels_initializer=labels_initializer,
                )
                # Remember the per-prototype shape so a checkpointed model
                # can rebuild a placeholder layer of the same shape.
                proto_shape = self.proto_layer.components.shape[1:]
                self.hparams.initialized_proto_shape = proto_shape
            else:
                # when restoring a checkpointed model: zero-initialized
                # components of the recorded shape
                self.proto_layer = LabeledComponents(
                    distribution=distribution,
                    components_initializer=ZerosCompInitializer(
                        self.hparams.initialized_proto_shape),
                )
        self.competition_layer = WTAC()

    @property
    def prototype_labels(self):
        """Detached CPU copy of the prototype labels."""
        return self.proto_layer.labels.detach().cpu()

    @property
    def num_classes(self):
        """Number of classes, as reported by the prototype layer."""
        return self.proto_layer.num_classes

    def compute_distances(self, x):
        """Distances from ``x`` to every prototype."""
        protos, _ = self.proto_layer()
        distances = self.distance_layer(x, protos)
        return distances

    def forward(self, x):
        """Soft class scores: softmin over the per-class minimum distances."""
        distances = self.compute_distances(x)
        _, plabels = self.proto_layer()
        winning = stratified_min_pooling(distances, plabels)
        y_pred = torch.nn.functional.softmin(winning, dim=1)
        return y_pred

    def predict_from_distances(self, distances):
        """Hard predictions from precomputed distances (no gradients)."""
        with torch.no_grad():
            _, plabels = self.proto_layer()
            y_pred = self.competition_layer(distances, plabels)
        return y_pred

    def predict(self, x):
        """Hard predictions for ``x`` (no gradients)."""
        with torch.no_grad():
            distances = self.compute_distances(x)
            y_pred = self.predict_from_distances(distances)
        return y_pred

    def log_acc(self, distances, targets, tag):
        """Log epoch-level accuracy of the predictions under ``tag``."""
        preds = self.predict_from_distances(distances)
        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
        # `.int()` because FloatTensors are assumed to be class probabilities

        self.log(tag,
                 accuracy,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

    def test_step(self, batch, batch_idx):
        """Log the accuracy on one test batch as ``test_acc``."""
        x, targets = batch

        preds = self.predict(x)
        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())

        self.log("test_acc", accuracy)
|
2021-06-04 20:20:32 +00:00
|
|
|
|
|
|
|
|
2021-07-14 17:17:05 +00:00
|
|
|
class ProtoTorchMixin:
    """Common marker base class shared by every ProtoTorch mixin."""
|
|
|
|
|
|
|
|
|
2021-06-09 16:21:12 +00:00
|
|
|
class NonGradientMixin(ProtoTorchMixin):
    """Mixin for models trained by custom, non-gradient update rules.

    Automatic optimization is switched off right after the cooperative
    ``__init__`` chain runs. Classes using this mixin must supply their own
    ``training_step``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable automatic optimization so that the model's own
        # training_step performs the updates.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Abstract hook: concrete non-gradient models must override this."""
        raise NotImplementedError
|
|
|
|
|
2021-05-12 14:36:22 +00:00
|
|
|
|
2021-06-09 16:21:12 +00:00
|
|
|
class ImagePrototypesMixin(ProtoTorchMixin):
    """Mixin for models with image prototypes."""

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Constrain the components to the range [0, 1] by clamping after updates.

        ``dataloader_idx`` is unused and defaults to 0. This way the hook also
        works with Lightning versions that call it without that argument,
        while callers that pass it keep working.
        """
        self.proto_layer.components.data.clamp_(0.0, 1.0)

    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
        """Arrange the prototypes into a single image grid.

        Args:
            num_columns: Passed to ``make_grid`` as ``nrow``.
            return_channels_last: If ``True``, permute the grid's axes with
                ``(1, 2, 0)``, i.e. channels-first to channels-last.

        Returns:
            The grid tensor, moved to the CPU.
        """
        # Imported lazily so torchvision is only needed when grids are used.
        from torchvision.utils import make_grid

        grid = make_grid(self.components, nrow=num_columns)
        if return_channels_last:
            grid = grid.permute((1, 2, 0))
        return grid.cpu()
|