import argparse import pytorch_lightning as pl import torch import torchmetrics from prototorch.functions.competitions import wtac from prototorch.functions.distances import euclidean_distance from prototorch.functions.initializers import get_initializer from prototorch.functions.losses import glvq_loss from prototorch.modules.prototypes import Prototypes1D class GLVQ(pl.LightningModule): """Generalized Learning Vector Quantization.""" def __init__(self, hparams, input_dim, nclasses, **kwargs): super().__init__() self.lr = hparams.lr self.hparams = hparams # self.save_hyperparameters( # "lr", # "prototypes_per_class", # "prototype_initializer", # ) self.proto_layer = Prototypes1D( input_dim=input_dim, nclasses=nclasses, prototypes_per_class=hparams.prototypes_per_class, prototype_initializer=hparams.prototype_initializer, **kwargs) self.train_acc = torchmetrics.Accuracy() @property def prototypes(self): return self.proto_layer.prototypes.detach().numpy() @property def prototype_labels(self): return self.proto_layer.prototype_labels.detach().numpy() def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) return optimizer @staticmethod def add_model_specific_args(parent_parser): parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--lr", type=float, default=1e-2) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--prototypes_per_class", type=int, default=1) parser.add_argument("--prototype_initializer", type=str, default="zeros") return parser def forward(self, x): protos = self.proto_layer.prototypes dis = euclidean_distance(x, protos) return dis def training_step(self, train_batch, batch_idx): x, y = train_batch x = x.view(x.size(0), -1) dis = self(x) plabels = self.proto_layer.prototype_labels mu = glvq_loss(dis, y, prototype_labels=plabels) loss = mu.sum(dim=0) self.log("train_loss", loss) with torch.no_grad(): preds = wtac(dis, plabels) # self.train_acc.update(preds.int(), y.int()) self.train_acc( preds.int(), y.int()) # FloatTensors are assumed to be class probabilities self.log("Training Accuracy", self.train_acc, on_step=False, on_epoch=True) return loss # def training_epoch_end(self, outs): # # Calling `self.train_acc.compute()` is # # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)` # self.log("train_acc_epoch", self.train_acc.compute()) def predict(self, x): with torch.no_grad(): # model.eval() # ?! d = self(x) plabels = self.proto_layer.prototype_labels y_pred = wtac(d, plabels) return y_pred.numpy() class ImageGLVQ(GLVQ): """GLVQ model that constrains the prototypes to the range [0, 1] by clamping after updates. """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.proto_layer.prototypes.data.clamp_(0., 1.)