36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
|
import pytorch_lightning as pl
|
||
|
import torch
|
||
|
from prototorch.core.components import Components
|
||
|
|
||
|
|
||
|
class ProtoTorchMixin(pl.LightningModule):
    """Common root shared by every ProtoTorch mixin.

    Because this class derives from ``pl.LightningModule``, each mixin built
    on top of it is itself a LightningModule.
    """
|
||
|
|
||
|
|
||
|
class NonGradientMixin(ProtoTorchMixin):
    """Mixin for custom non-gradient optimization.

    Disables Lightning's automatic optimization so that subclasses perform
    their own parameter updates inside ``training_step``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Opt out of Lightning's automatic optimization loop; the subclass's
        # ``training_step`` is responsible for any parameter updates.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Perform one non-gradient update step; must be overridden.

        Args:
            train_batch: The current training batch.
            batch_idx: Index of the batch within the epoch.
            optimizer_idx: Accepted for signature compatibility with the
                Lightning ``training_step`` hook; unused here.

        Raises:
            NotImplementedError: Always; subclasses must provide the update.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement `training_step` "
            "for its non-gradient optimization."
        )
|
||
|
|
||
|
|
||
|
class ImagePrototypesMixin(ProtoTorchMixin):
    """Mixin for models with image prototypes.

    Clamps the prototype components to [0, 1] after every training batch and
    provides a helper that tiles the prototypes into a single image grid.
    """

    # Layer holding the prototypes; expected to be set up by the model.
    proto_layer: Components
    # Prototype tensor used by ``get_prototype_grid``; expected to be
    # provided by the model.
    components: torch.Tensor

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Constrain the components to the range [0, 1] by clamping after updates.

        NOTE(review): this relies on ``proto_layer.components`` sharing storage
        with the trainable prototypes. If it returns a copy (e.g. after a
        device transfer), the clamp has no effect -- confirm against
        ``Components``.
        """
        # In-place clamp under ``no_grad`` instead of mutating ``.data``:
        # same numerical result, but the change stays visible to autograd's
        # version tracking rather than silently bypassing it.
        with torch.no_grad():
            self.proto_layer.components.clamp_(0.0, 1.0)

    def get_prototype_grid(
        self, num_columns=2, return_channels_last=True, **grid_kwargs
    ):
        """Tile the prototype images into a single grid tensor.

        Args:
            num_columns: Number of images per grid row (torchvision's
                ``nrow`` argument).
            return_channels_last: If True, permute the grid from
                (C, H, W) to (H, W, C).
            **grid_kwargs: Extra keyword arguments forwarded unchanged to
                ``torchvision.utils.make_grid`` (e.g. ``padding``,
                ``normalize``, ``pad_value``).

        Returns:
            The grid tensor, moved to the CPU.
        """
        # Local import: torchvision is only loaded when this method is called.
        from torchvision.utils import make_grid

        # NOTE(review): assumes ``self.components`` is a batch of images
        # shaped (N, C, H, W) or a list of (C, H, W) images -- confirm.
        grid = make_grid(self.components, nrow=num_columns, **grid_kwargs)
        if return_channels_last:
            grid = grid.permute((1, 2, 0))
        return grid.cpu()
|