From fa928afe2ca6a0f04c2ff6f397be8b59009ae90b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 1 Sep 2021 10:49:57 +0200 Subject: [PATCH] feat(vis): 2D EV projection for GMLVQ --- examples/gmlvq_iris.py | 58 +++++++++++++++++++++++++++++++++++++++ prototorch/models/glvq.py | 6 ++++ prototorch/models/vis.py | 33 ++++++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 examples/gmlvq_iris.py diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py new file mode 100644 index 0000000..f326f8b --- /dev/null +++ b/examples/gmlvq_iris.py @@ -0,0 +1,58 @@ +"""GMLVQ example using the Iris dataset.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch +from torch.optim.lr_scheduler import ExponentialLR + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = pt.datasets.Iris() + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + + # Hyperparameters + hparams = dict( + input_dim=4, + latent_dim=4, + distribution={ + "num_classes": 3, + "per_class": 2 + }, + proto_lr=0.01, + bb_lr=0.01, + ) + + # Initialize the model + model = pt.models.GMLVQ( + hparams, + optimizer=torch.optim.Adam, + prototypes_initializer=pt.initializers.SMCI(train_ds), + lr_scheduler=ExponentialLR, + lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), + ) + + # Compute intermediate input and output sizes + model.example_input_array = torch.zeros(4, 4) + + # Callbacks + vis = pt.models.VisGMLVQ2D(data=train_ds) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[vis], + weights_summary="full", + accelerator="ddp", + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index ce4e1d5..c781784 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -251,6 +251,12 @@ class GMLVQ(GLVQ): def omega_matrix(self): return self._omega.detach().cpu() + @property + def lambda_matrix(self): + omega = self._omega.detach() # (input_dim, latent_dim) + lam = omega @ omega.T + return lam.detach().cpu() + def compute_distances(self, x): protos, _ = self.proto_layer() distances = self.distance_layer(x, protos, self._omega) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 0f39f1a..9744a9d 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -178,6 +178,39 @@ class VisSiameseGLVQ2D(Vis2DAbstract): self.log_and_display(trainer, pl_module) +class VisGMLVQ2D(Vis2DAbstract): + def __init__(self, *args, ev_proj=True, **kwargs): + super().__init__(*args, **kwargs) + self.ev_proj = ev_proj + + def on_epoch_end(self, trainer, pl_module): + if not self.precheck(trainer): + return True + + protos = pl_module.prototypes + plabels = pl_module.prototype_labels + x_train, y_train = self.x_train, self.y_train + device = pl_module.device + omega = pl_module._omega.detach() + lam = omega @ omega.T + u, _, _ = torch.pca_lowrank(lam, q=2) + with torch.no_grad(): + x_train = torch.Tensor(x_train).to(device) + x_train = x_train @ u + x_train = x_train.cpu().detach() + if self.show_protos: + with torch.no_grad(): + protos = torch.Tensor(protos).to(device) + protos = protos @ u + protos = protos.cpu().detach() + ax = self.setup_ax() + self.plot_data(ax, x_train, y_train) + if self.show_protos: + self.plot_protos(ax, protos, plabels) + + self.log_and_display(trainer, pl_module) + + class VisCBC2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): if not self.precheck(trainer):