From f6994dfd837d30bce5a59a0a62162184fca96062 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Wed, 21 Apr 2021 14:51:34 +0200
Subject: [PATCH] Add glvq model

---
 prototorch/models/glvq.py | 57 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 prototorch/models/glvq.py

diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
new file mode 100644
index 0000000..e04cd75
--- /dev/null
+++ b/prototorch/models/glvq.py
@@ -0,0 +1,57 @@
+import pytorch_lightning as pl
+import torch
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import euclidean_distance
+from prototorch.functions.losses import glvq_loss
+from prototorch.modules.prototypes import Prototypes1D
+
+
+class GLVQ(pl.LightningModule):
+    """Generalized Learning Vector Quantization."""
+    def __init__(self, lr=1e-3, **kwargs):
+        super().__init__()
+        self.lr = lr
+        self.proto_layer = Prototypes1D(**kwargs)
+
+    @property
+    def prototypes(self):
+        return self.proto_layer.prototypes.detach().cpu().numpy()
+
+    @property
+    def prototype_labels(self):
+        return self.proto_layer.prototype_labels.detach().cpu().numpy()
+
+    def forward(self, x):
+        protos = self.proto_layer.prototypes
+        dis = euclidean_distance(x, protos)
+        return dis
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        return optimizer
+
+    def training_step(self, train_batch, batch_idx):
+        x, y = train_batch
+        x = x.view(x.size(0), -1)
+        dis = self(x)
+        plabels = self.proto_layer.prototype_labels
+        # mu lies in (-1, 1); positive values indicate misclassification
+        mu = glvq_loss(dis, y, prototype_labels=plabels)
+        loss = mu.sum(dim=0)
+        self.log("train_loss", loss)
+        return loss
+
+    def predict(self, x):
+        with torch.no_grad():
+            d = self(x)
+            plabels = self.proto_layer.prototype_labels
+            y_pred = wtac(d, plabels)
+        return y_pred.cpu().numpy()
+
+
+class ImageGLVQ(GLVQ):
+    """GLVQ model that constrains the prototypes to the range [0, 1] by
+    clamping after updates.
+    """
+    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        self.proto_layer.prototypes.data.clamp_(0., 1.)
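
Note on the two prototorch calls the model relies on: glvq_loss implements the
classic GLVQ cost of Sato and Yamada, mu = (d_correct - d_wrong) / (d_correct + d_wrong),
where d_correct is the distance to the nearest prototype carrying the sample's
label and d_wrong the distance to the nearest prototype with any other label;
wtac is a winner-takes-all competition that returns the label of the nearest
prototype. Below is a minimal reference sketch of both, assuming distances of
shape (batch, num_prototypes). It approximates what the library computes; it is
not the actual prototorch implementation.

import torch


def glvq_loss_sketch(distances, targets, prototype_labels):
    """Sketch of the GLVQ cost: mu in (-1, 1), negative means correct."""
    # Boolean mask: which prototypes carry each sample's target label?
    matcher = targets.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_correct = torch.where(matcher, distances, inf).min(dim=1).values
    d_wrong = torch.where(~matcher, distances, inf).min(dim=1).values
    return (d_correct - d_wrong) / (d_correct + d_wrong)


def wtac_sketch(distances, prototype_labels):
    """Winner-takes-all competition: the nearest prototype's label wins."""
    winners = distances.argmin(dim=1)
    return prototype_labels[winners]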
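
A hypothetical end-to-end usage sketch follows. The Prototypes1D keyword
arguments (input_dim, nclasses, prototypes_per_class, prototype_initializer)
are assumptions about the contemporary prototorch API rather than anything
this patch specifies; they are simply forwarded through **kwargs in
GLVQ.__init__, so adjust them to whatever Prototypes1D actually accepts.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from prototorch.models.glvq import GLVQ

# Toy two-class dataset in 2-D.
x_train = torch.randn(128, 2)
y_train = torch.randint(0, 2, (128,))
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32)

# NOTE: every kwarg except lr is an assumed Prototypes1D parameter.
model = GLVQ(
    lr=1e-2,
    input_dim=2,
    nclasses=2,
    prototypes_per_class=1,
    prototype_initializer="zeros",
)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)

y_pred = model.predict(torch.randn(8, 2))  # numpy array of predicted labels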