diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py
index f6886c6..1f22954 100644
--- a/prototorch/models/abstract.py
+++ b/prototorch/models/abstract.py
@@ -22,7 +22,16 @@ from prototorch.nn.wrappers import LambdaLayer
 
 
 class ProtoTorchBolt(pl.LightningModule):
-    """All ProtoTorch models are ProtoTorch Bolts."""
+    """All ProtoTorch models are ProtoTorch Bolts.
+
+    hparams:
+    - lr: learning rate
+
+    kwargs:
+    - optimizer: optimizer class
+    - lr_scheduler: learning rate scheduler class
+    - lr_scheduler_kwargs: learning rate scheduler kwargs
+    """
 
     def __init__(self, hparams, **kwargs):
         super().__init__()
@@ -65,6 +74,11 @@
 
 
 class PrototypeModel(ProtoTorchBolt):
+    """Abstract Prototype Model
+
+    kwargs:
+    - distance_fn: distance function
+    """
     proto_layer: AbstractComponents
 
     def __init__(self, hparams, **kwargs):
@@ -203,35 +217,3 @@
         accuracy = torchmetrics.functional.accuracy(preds.int(),
                                                     targets.int())
         self.log("test_acc", accuracy)
-
-
-class ProtoTorchMixin(object):
-    """All mixins are ProtoTorchMixins."""
-
-
-class NonGradientMixin(ProtoTorchMixin):
-    """Mixin for custom non-gradient optimization."""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class ImagePrototypesMixin(ProtoTorchMixin):
-    """Mixin for models with image prototypes."""
-    proto_layer: Components
-    components: torch.Tensor
-
-    def on_train_batch_end(self, outputs, batch, batch_idx):
-        """Constrain the components to the range [0, 1] by clamping after updates."""
-        self.proto_layer.components.data.clamp_(0.0, 1.0)
-
-    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
-        from torchvision.utils import make_grid
-        grid = make_grid(self.components, nrow=num_columns)
-        if return_channels_last:
-            grid = grid.permute((1, 2, 0))
-        return grid.cpu()
diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py
index b58ccc1..7d53d4c 100644
--- a/prototorch/models/callbacks.py
+++ b/prototorch/models/callbacks.py
@@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback):
             return None
 
         ratios = pl_module.prototype_win_ratios.mean(dim=0)
-        to_prune = torch.arange(len(ratios))[ratios < self.threshold]
-        to_prune = to_prune.tolist()
+        to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold]
+        to_prune = to_prune_tensor.tolist()
         prune_labels = pl_module.prototype_labels[to_prune]
         if self.prune_quota_per_epoch > 0:
             to_prune = to_prune[:self.prune_quota_per_epoch]
diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py
index 2e38394..11c3804 100644
--- a/prototorch/models/cbc.py
+++ b/prototorch/models/cbc.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 import torchmetrics
 from prototorch.core.competitions import CBCC
 from prototorch.core.components import ReasoningComponents
@@ -7,12 +8,13 @@ from prototorch.core.losses import MarginLoss
 from prototorch.core.similarities import euclidean_similarity
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import ImagePrototypesMixin
 from .glvq import SiameseGLVQ
+from .mixins import ImagePrototypesMixin
 
 
 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
+    proto_layer: ReasoningComponents
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, skip_proto_layer=True, **kwargs)
@@ -22,7 +24,7 @@ class CBC(SiameseGLVQ):
         reasonings_initializer = kwargs.get("reasonings_initializer",
                                             RandomReasoningsInitializer())
         self.components_layer = ReasoningComponents(
-            self.hparams.distribution,
+            self.hparams["distribution"],
             components_initializer=components_initializer,
             reasonings_initializer=reasonings_initializer,
         )
@@ -32,7 +34,7 @@ class CBC(SiameseGLVQ):
         # Namespace hook
         self.proto_layer = self.components_layer
 
-        self.loss = MarginLoss(self.hparams.margin)
+        self.loss = MarginLoss(self.hparams["margin"])
 
     def forward(self, x):
         components, reasonings = self.components_layer()
@@ -48,7 +50,7 @@ class CBC(SiameseGLVQ):
         x, y = batch
         y_pred = self(x)
         num_classes = self.num_classes
-        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
+        y_true = F.one_hot(y.long(), num_classes=num_classes)
         loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss
 
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index 81ddaef..7f85f2a 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -17,8 +17,9 @@ from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter
 
-from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
+from .abstract import SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
+from .mixins import ImagePrototypesMixin
 
 
 class GLVQ(SupervisedPrototypeModel):
@@ -46,19 +47,24 @@ class GLVQ(SupervisedPrototypeModel):
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            torch.zeros(self.num_prototypes, device=self.device),
+        )
 
     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes,
+            dtype=torch.long,
+            device=self.device,
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices,
+            sorted=True,
+            return_counts=True,
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
         self.prototype_win_ratios = torch.vstack([
@@ -81,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel):
         return train_loss
 
     def validation_step(self, batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` handled by pl
         out, val_loss = self.shared_step(batch, batch_idx)
         self.log("val_loss", val_loss)
         self.log_acc(out, batch[-1], tag="val_acc")
         return val_loss
 
     def test_step(self, batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` handled by pl
         out, test_loss = self.shared_step(batch, batch_idx)
         self.log_acc(out, batch[-1], tag="test_acc")
         return test_loss
@@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
             test_loss += batch_loss.item()
         self.log("test_loss", test_loss)
 
-    # TODO
-    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
-    #     pass
-
 
 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
@@ -113,19 +113,23 @@ class SiameseGLVQ(GLVQ):
 
     """
 
-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self,
+        hparams,
+        backbone=torch.nn.Identity(),
+        both_path_gradients=False,
+        **kwargs,
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
 
     def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams["proto_lr"])
+        proto_opt = self.optimizer(
+            self.proto_layer.parameters(),
+            lr=self.hparams["proto_lr"],
+        )
         # Only add a backbone optimizer if backbone has trainable parameters
         bb_params = list(self.backbone.parameters())
         if (bb_params):
@@ -266,13 +270,19 @@ class GMLVQ(GLVQ):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer",
+            EyeLinearTransformInitializer(),
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"],
+            self.hparams["latent_dim"],
+        )
         self.register_parameter("_omega", Parameter(omega))
-        self.backbone = LambdaLayer(lambda x: x @ self._omega,
-                                    name="omega matrix")
+        self.backbone = LambdaLayer(
+            lambda x: x @ self._omega,
+            name="omega matrix",
+        )
 
     @property
     def omega_matrix(self):
diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py
index aa893a2..6daca06 100644
--- a/prototorch/models/lvq.py
+++ b/prototorch/models/lvq.py
@@ -1,20 +1,21 @@
 """LVQ models that are optimized using non-gradient methods."""
 
 import logging
+from collections import OrderedDict
 
 from prototorch.core.losses import _get_dp_dm
 from prototorch.nn.activations import get_activation
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import NonGradientMixin
 from .glvq import GLVQ
+from .mixins import NonGradientMixin
 
 
 class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos, plables = self.proto_layer()
+        protos, plabels = self.proto_layer()
         x, y = train_batch
         dis = self.compute_distances(x)
         # TODO Vectorized implementation
@@ -28,9 +29,11 @@ class LVQ1(NonGradientMixin, GLVQ):
             else:
                 shift = protos[w] - xi
             updated_protos = protos + 0.0
-            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
+            updated_protos[w] = protos[w] + (self.hparams["lr"] * shift)
+            self.proto_layer.load_state_dict(
+                OrderedDict(_components=updated_protos),
+                strict=False,
+            )
 
         logging.debug(f"dis={dis}")
         logging.debug(f"y={y}")
@@ -58,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ):
             shiftp = xi - protos[wp]
             shiftn = protos[wn] - xi
             updated_protos = protos + 0.0
-            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
-            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
+            updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn)
+            self.proto_layer.load_state_dict(
+                OrderedDict(_components=updated_protos),
+                strict=False,
+            )
 
         # Logging
         self.log_acc(dis, y, tag="train_acc")
@@ -80,14 +85,17 @@ class MedianLVQ(NonGradientMixin, GLVQ):
         super().__init__(hparams, **kwargs)
 
         self.transfer_layer = LambdaLayer(
-            get_activation(self.hparams.transfer_fn))
+            get_activation(self.hparams["transfer_fn"]))
 
     def _f(self, x, y, protos, plabels):
         d = self.distance_layer(x, protos)
-        dp, dm = _get_dp_dm(d, y, plabels)
+        dp, dm = _get_dp_dm(d, y, plabels, with_indices=False)
         mu = (dp - dm) / (dp + dm)
-        invmu = -1.0 * mu
-        f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
+        negative_mu = -1.0 * mu
+        f = self.transfer_layer(
+            negative_mu,
+            beta=self.hparams["transfer_beta"],
+        ) + 1.0
         return f
 
     def expectation(self, x, y, protos, plabels):
@@ -118,8 +126,10 @@ class MedianLVQ(NonGradientMixin, GLVQ):
                 _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
                 if _lower_bound > lower_bound:
                     logging.debug(f"Updating prototype {i} to data {k}...")
-                    self.proto_layer.load_state_dict({"_components": _protos},
-                                                     strict=False)
+                    self.proto_layer.load_state_dict(
+                        OrderedDict(_components=_protos),
+                        strict=False,
+                    )
                     break
 
         # Logging
diff --git a/prototorch/models/mixins.py b/prototorch/models/mixins.py
new file mode 100644
index 0000000..7599f20
--- /dev/null
+++ b/prototorch/models/mixins.py
@@ -0,0 +1,35 @@
+import pytorch_lightning as pl
+import torch
+from prototorch.core.components import Components
+
+
+class ProtoTorchMixin(pl.LightningModule):
+    """All mixins are ProtoTorchMixins."""
+
+
+class NonGradientMixin(ProtoTorchMixin):
+    """Mixin for custom non-gradient optimization."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class ImagePrototypesMixin(ProtoTorchMixin):
+    """Mixin for models with image prototypes."""
+    proto_layer: Components
+    components: torch.Tensor
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
+        self.proto_layer.components.data.clamp_(0.0, 1.0)
+
+    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
+        from torchvision.utils import make_grid
+        grid = make_grid(self.components, nrow=num_columns)
+        if return_channels_last:
+            grid = grid.permute((1, 2, 0))
+        return grid.cpu()
diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py
index 8833de2..c2a5105 100644
--- a/prototorch/models/unsupervised.py
+++ b/prototorch/models/unsupervised.py
@@ -6,9 +6,10 @@ from prototorch.core.competitions import wtac
 from prototorch.core.distances import squared_euclidean_distance
 from prototorch.core.losses import NeuralGasEnergy
 
-from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
 from .callbacks import GNGCallback
 from .extras import ConnectionTopology
+from .mixins import NonGradientMixin
 
 
 class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
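For orientation only (not part of the patch): a minimal sketch of how the relocated mixins are combined with a concrete model class, mirroring the MRO already used in lvq.py (e.g. LVQ1(NonGradientMixin, GLVQ)). The class name ImageGLVQ below is illustrative.

    from prototorch.models.glvq import GLVQ
    from prototorch.models.mixins import ImagePrototypesMixin


    class ImageGLVQ(ImagePrototypesMixin, GLVQ):
        """GLVQ whose image prototypes are clamped to [0, 1] after each update."""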