prototorch_models/prototorch/models/abstract.py

238 lines
7.7 KiB
Python
Raw Normal View History

2021-06-04 20:20:32 +00:00
"""Abstract classes to be inherited by prototorch models."""
import logging
2021-04-29 17:14:33 +00:00
import pytorch_lightning as pl
2021-06-04 20:20:32 +00:00
import torch
import torch.nn.functional as F
2021-06-04 20:20:32 +00:00
import torchmetrics
2022-05-16 09:12:53 +00:00
from prototorch.core.competitions import WTAC
from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
2022-05-16 09:12:53 +00:00
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.core.pooling import stratified_min_pooling
from prototorch.nn.wrappers import LambdaLayer
2021-06-04 20:20:32 +00:00
class ProtoTorchBolt(pl.LightningModule):
    """Common LightningModule base for every ProtoTorch model.

    Stores the given hyperparameters (``lr`` defaults to 0.01) and builds
    the optimizer and an optional learning-rate scheduler from keyword
    arguments: ``optimizer`` (default ``torch.optim.Adam``),
    ``lr_scheduler`` (default ``None``) and ``lr_scheduler_kwargs``.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__()
        # Persist the hyperparameters on `self.hparams`, then fill defaults.
        self.save_hyperparameters(hparams)
        self.hparams.setdefault("lr", 0.01)

        # Optimizer / scheduler configuration.
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
        self.lr_scheduler = kwargs.get("lr_scheduler", None)
        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

    def configure_optimizers(self):
        """Build the optimizer and, if configured, a per-step LR scheduler.

        Returns:
            The bare optimizer when no ``lr_scheduler`` was supplied,
            otherwise ``([optimizer], [scheduler_config])`` with the
            scheduler stepped after every training step.
        """
        optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
        if self.lr_scheduler is None:
            return optimizer
        scheduler_config = dict(
            scheduler=self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs),
            interval="step",
        )
        return [optimizer], [scheduler_config]

    def reconfigure_optimizers(self):
        """Ask the attached trainer's strategy to set the optimizers up again.

        Logs a warning and does nothing when no trainer is attached.
        """
        if not self.trainer:
            logging.warning("No trainer to reconfigure optimizers!")
            return
        self.trainer.strategy.setup_optimizers(self.trainer)

    def __repr__(self):
        # Indent the default module repr by one tab and wrap it.
        body = "".join(f"\t{row}\n" for row in super().__repr__().splitlines())
        return f"ProtoTorch Bolt(\n{body})"
class PrototypeModel(ProtoTorchBolt):
    """Base for models built around a layer of prototypes.

    Subclasses assign ``self.proto_layer``. Distances to the prototypes
    are computed by a ``LambdaLayer`` around ``distance_fn`` (keyword
    argument, default ``euclidean_distance``).
    """

    proto_layer: AbstractComponents

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.distance_layer = LambdaLayer(
            kwargs.get("distance_fn", euclidean_distance))

    @property
    def num_prototypes(self):
        """Number of prototypes currently held by the prototype layer."""
        return len(self.proto_layer.components)

    @property
    def prototypes(self):
        """The prototypes, detached from the autograd graph and moved to CPU."""
        return self.proto_layer.components.detach().cpu()

    @property
    def components(self):
        """Alias for ``prototypes``."""
        return self.prototypes

    def add_prototypes(self, *args, **kwargs):
        """Forward to ``proto_layer.add_components`` and resync model state."""
        self.proto_layer.add_components(*args, **kwargs)
        self._after_prototypes_changed()

    def remove_prototypes(self, indices):
        """Forward to ``proto_layer.remove_components`` and resync model state."""
        self.proto_layer.remove_components(indices)
        self._after_prototypes_changed()

    def _after_prototypes_changed(self):
        # Mirror the layer's distribution into the saved hyperparameters and
        # rebuild the optimizers (presumably so they see the changed
        # parameter set).
        self.hparams["distribution"] = self.proto_layer.distribution
        self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
    """Prototype model whose prototypes carry no class labels."""

    proto_layer: Components

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        initializer = kwargs.get("prototypes_initializer", None)
        # Without an initializer no prototype layer is built here.
        # NOTE(review): caller/subclass must then provide `proto_layer`.
        if initializer is None:
            return
        self.proto_layer = Components(
            self.hparams["num_prototypes"],
            initializer=initializer,
        )

    def compute_distances(self, x):
        """Distances between ``x`` and the prototypes (cast to ``x``'s type)."""
        return self.distance_layer(x, self.proto_layer().type_as(x))

    def forward(self, x):
        """The model output is the distance matrix itself."""
        return self.compute_distances(x)
class SupervisedPrototypeModel(PrototypeModel):
    """Prototype model whose prototypes carry class labels.

    Keyword arguments read here: ``prototypes_initializer`` (default
    ``None``) and ``labels_initializer`` (default ``LabelsInitializer()``).
    Pass ``skip_proto_layer=True`` when a subclass builds its own
    prototype layer.
    """

    proto_layer: LabeledComponents

    def __init__(self, hparams, skip_proto_layer=False, **kwargs):
        super().__init__(hparams, **kwargs)

        # Read via `self.hparams` (filled by `save_hyperparameters`) instead
        # of the raw argument, so non-dict hparams such as an
        # `argparse.Namespace` are supported as well.
        distribution = self.hparams.get("distribution", None)
        prototypes_initializer = kwargs.get("prototypes_initializer", None)
        labels_initializer = kwargs.get("labels_initializer",
                                        LabelsInitializer())

        if not skip_proto_layer:
            # Subclasses that need no customized prototype layer land here.
            if prototypes_initializer is not None:
                # Building a new model: initialize the prototypes and record
                # their shape so the model can be rebuilt from a checkpoint.
                self.proto_layer = LabeledComponents(
                    distribution=distribution,
                    components_initializer=prototypes_initializer,
                    labels_initializer=labels_initializer,
                )
                proto_shape = self.proto_layer.components.shape[1:]
                self.hparams["initialized_proto_shape"] = proto_shape
            else:
                # Restoring a checkpointed model: allocate zero prototypes of
                # the recorded shape (presumably overwritten by the loaded
                # state dict).
                self.proto_layer = LabeledComponents(
                    distribution=distribution,
                    components_initializer=ZerosCompInitializer(
                        self.hparams["initialized_proto_shape"]),
                )
        self.competition_layer = WTAC()

    @property
    def prototype_labels(self):
        """Prototype labels, detached from the autograd graph and on CPU."""
        return self.proto_layer.labels.detach().cpu()

    @property
    def num_classes(self):
        """Number of classes reported by the prototype layer."""
        return self.proto_layer.num_classes

    def compute_distances(self, x):
        """Distances between inputs ``x`` and the prototypes."""
        protos, _ = self.proto_layer()
        distances = self.distance_layer(x, protos)
        return distances

    def forward(self, x):
        """Soft class scores: softmin over per-class minimum distances."""
        distances = self.compute_distances(x)
        _, plabels = self.proto_layer()
        winning = stratified_min_pooling(distances, plabels)
        y_pred = F.softmin(winning, dim=1)
        return y_pred

    def predict_from_distances(self, distances):
        """Hard label predictions from precomputed distances (no grad)."""
        with torch.no_grad():
            _, plabels = self.proto_layer()
            y_pred = self.competition_layer(distances, plabels)
        return y_pred

    def predict(self, x):
        """Hard label predictions for inputs ``x`` (no grad)."""
        with torch.no_grad():
            distances = self.compute_distances(x)
            y_pred = self.predict_from_distances(distances)
        return y_pred

    @staticmethod
    def _accuracy(preds, targets):
        """Fraction of ``preds`` equal to ``targets`` as a float tensor.

        Computed directly instead of via ``torchmetrics.functional.accuracy``,
        which requires a ``task`` argument from torchmetrics 0.11 on. For
        label tensors this equals the micro-averaged accuracy the previous
        call computed.
        """
        # `.int()` so float-typed labels compare by class index; flatten so
        # e.g. (N, 1) targets do not broadcast against (N,) predictions.
        matches = preds.int().reshape(-1) == targets.int().reshape(-1)
        return matches.float().mean()

    def log_acc(self, distances, targets, tag):
        """Log the epoch-level accuracy of the predictions under ``tag``."""
        preds = self.predict_from_distances(distances)
        accuracy = self._accuracy(preds, targets)
        self.log(tag,
                 accuracy,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

    def test_step(self, batch, batch_idx):
        """Log the accuracy on a test batch as ``test_acc``."""
        x, targets = batch
        preds = self.predict(x)
        accuracy = self._accuracy(preds, targets)
        self.log("test_acc", accuracy)
class ProtoTorchMixin:
    """Common marker base class shared by every ProtoTorch mixin."""
class NonGradientMixin(ProtoTorchMixin):
    """Mixin for models trained by a custom, non-gradient update rule.

    Disables Lightning's automatic optimization; the model must override
    ``training_step`` with its own update logic.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Turn off Lightning's automatic backward/optimizer stepping.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Placeholder; concrete models supply the update rule."""
        raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
    """Mixin for models whose prototypes are images.

    Keeps prototype values inside [0, 1] during training and can tile the
    prototypes into a single image grid.
    """

    proto_layer: Components
    components: torch.Tensor

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Clamp the prototype values in place to [0, 1] after each batch."""
        self.proto_layer.components.data.clamp_(min=0.0, max=1.0)

    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
        """Tile ``self.components`` into one grid with torchvision's ``make_grid``.

        Args:
            num_columns: images per grid row (``nrow`` of ``make_grid``).
            return_channels_last: if true, move the grid's first axis to
                the end via ``permute((1, 2, 0))``.

        Returns:
            The grid tensor, on CPU.
        """
        # Local import: torchvision is only needed for this helper.
        from torchvision.utils import make_grid

        grid = make_grid(self.components, nrow=num_columns)
        arranged = grid.permute((1, 2, 0)) if return_channels_last else grid
        return arranged.cpu()