From 8bad54fc2d7599a981775ddb6f3593fefa41e49d Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 17:11:06 +0200 Subject: [PATCH 01/27] Small fix on example script --- examples/glvq_iris_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/glvq_iris_v1.py b/examples/glvq_iris_v1.py index aff2d26..893e177 100644 --- a/examples/glvq_iris_v1.py +++ b/examples/glvq_iris_v1.py @@ -28,7 +28,7 @@ if __name__ == "__main__": ) # Initialize the model - model = GLVQ(hparams, data=[x_train, y_train]) + model = GLVQ(hparams) # Setup trainer trainer = pl.Trainer( From fef73e2fbf0fe983845d3a37e9f5fa35934c5a5b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 19:09:10 +0200 Subject: [PATCH 02/27] [BUG] NaN when training with selection initializer How to reproduce: Run the `glvq_spiral.py` file under `examples/`. The error seems to occur when using a lot of prototypes in combination with the `StratifiedSelectionInitializer`. Using only a prototype per class, or using another initializer like the `StratifiedMeanInitializer` seems to make the problem go away. --- examples/glvq_spiral.py | 56 +++++++++++++++++++ prototorch/models/callbacks/visualization.py | 57 +++++++++----------- 2 files changed, 82 insertions(+), 31 deletions(-) create mode 100644 examples/glvq_spiral.py diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py new file mode 100644 index 0000000..a2aebcc --- /dev/null +++ b/examples/glvq_spiral.py @@ -0,0 +1,56 @@ +"""GLVQ example using the spiral dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from prototorch.datasets.spiral import make_spiral +from prototorch.models.callbacks.visualization import VisGLVQ2D +from prototorch.models.glvq import GLVQ +from torch.utils.data import DataLoader + + +class StopOnNaN(pl.Callback): + def __init__(self, param): + super().__init__() + self.param = param + + def on_epoch_end(self, trainer, pl_module, logs={}): + if torch.isnan(self.param).any(): + raise ValueError("NaN encountered. Stopping.") + + +if __name__ == "__main__": + # Dataset + x_train, y_train = make_spiral(n_samples=600, noise=0.6) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, num_workers=0, batch_size=256) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=20, + # prototype_initializer=cinit.SSI(torch.Tensor(x_train), + prototype_initializer=cinit.SMI(torch.Tensor(x_train), + torch.Tensor(y_train)), + lr=0.01, + ) + + # Initialize the model + model = GLVQ(hparams) + + # Callbacks + vis = VisGLVQ2D(x_train, y_train) + # vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True) + snan = StopOnNaN(model.proto_layer.components) + + # Setup trainer + trainer = pl.Trainer( + max_epochs=200, + callbacks=[vis, snan], + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 11bc729..4a946a5 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -261,20 +261,29 @@ class VisPointProtos(VisWeights): self._show_and_save(epoch) -class VisGLVQ2D(pl.Callback): +class Vis2DAbstract(pl.Callback): def __init__(self, x_train, y_train, title="Prototype Visualization", - cmap="viridis"): + cmap="viridis", + show_last_only=False, + block=False): super().__init__() self.x_train = x_train self.y_train = y_train self.title = title self.fig = plt.figure(self.title) self.cmap = cmap + self.show_last_only = show_last_only + self.block = block + +class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): + if self.show_last_only: + if trainer.current_epoch != trainer.max_epochs - 1: + return protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train @@ -306,22 +315,13 @@ class VisGLVQ2D(pl.Callback): ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) + if not self.block: + plt.pause(0.01) + else: + plt.show(block=True) -class VisSiameseGLVQ2D(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - +class VisSiameseGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): protos = pl_module.prototypes plabels = pl_module.prototype_labels @@ -361,22 +361,14 @@ class VisSiameseGLVQ2D(pl.Callback): global_step=trainer.current_epoch, close=False, ) - plt.pause(0.1) + + if not self.block: + plt.pause(0.01) + else: + plt.show(block=True) -class VisNG2D(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Neural Gas Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - +class VisNG2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() @@ -410,4 +402,7 @@ class VisNG2D(pl.Callback): "k-", ) - plt.pause(0.01) + if not self.block: + plt.pause(0.01) + else: + plt.show(block=True) From ccaa52c408d375625c210ff728f884a1b4d5f334 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 19:14:33 +0200 Subject: [PATCH 03/27] Add missing abstract.py file --- prototorch/models/abstract.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 prototorch/models/abstract.py diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py new file mode 100644 index 0000000..65ce8ae --- /dev/null +++ b/prototorch/models/abstract.py @@ -0,0 +1,14 @@ +import pytorch_lightning as pl +import torch + + +class AbstractLightningModel(pl.LightningModule): + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) + return optimizer + + +class AbstractPrototypeModel(AbstractLightningModel): + @property + def prototypes(self): + return self.proto_layer.components.detach().numpy() From e44516fc49638287b23f2cedd8638115eeade2e1 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 19:25:08 +0200 Subject: [PATCH 04/27] Update example script --- examples/glvq_spiral.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index a2aebcc..a2d3571 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -32,9 +32,9 @@ if __name__ == "__main__": hparams = dict( nclasses=2, prototypes_per_class=20, - # prototype_initializer=cinit.SSI(torch.Tensor(x_train), - prototype_initializer=cinit.SMI(torch.Tensor(x_train), - torch.Tensor(y_train)), + prototype_initializer=cinit.SSI(torch.Tensor(x_train), + torch.Tensor(y_train), + noise=1e-7), lr=0.01, ) @@ -42,8 +42,7 @@ if __name__ == "__main__": model = GLVQ(hparams) # Callbacks - vis = VisGLVQ2D(x_train, y_train) - # vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True) + vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True) snan = StopOnNaN(model.proto_layer.components) # Setup trainer From db7bb7619fcec8efeba2c843f634303714dce8ee Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 22:36:10 +0200 Subject: [PATCH 05/27] Add border argument in visualization callback --- prototorch/models/callbacks/visualization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 4a946a5..8692eff 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -267,6 +267,7 @@ class Vis2DAbstract(pl.Callback): y_train, title="Prototype Visualization", cmap="viridis", + border=1, show_last_only=False, block=False): super().__init__() @@ -275,6 +276,7 @@ class Vis2DAbstract(pl.Callback): self.title = title self.fig = plt.figure(self.title) self.cmap = cmap + self.border = border self.show_last_only = show_last_only self.block = block @@ -343,8 +345,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract): s=50, ) x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 + x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border + y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), np.arange(y_min, y_max, 1 / 50)) mesh_input = np.c_[xx.ravel(), yy.ravel()] From 6dd9b1492c1d64ab93b64ee069b1eeb3090b865f Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 23:37:22 +0200 Subject: [PATCH 06/27] Add more models --- README.md | 10 ++-- prototorch/models/glvq.py | 96 +++++++++++++++++++++++++++++---------- 2 files changed, 80 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 6b91b99..42450b3 100644 --- a/README.md +++ b/README.md @@ -48,13 +48,17 @@ To assist in the development process, you may also find it useful to install - Neural Gas ## Work in Progress + - CBC +- LVQMLN +- GMLVQ +- Limited-Rank GMLVQ ## Planned models -- GMLVQ + - Local-Matrix GMLVQ -- Limited-Rank GMLVQ - GTLVQ - RSLVQ - PLVQ -- LVQMLN +- SILVQ +- KNN diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 61749d7..4e13cf2 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -37,32 +37,28 @@ class GLVQ(AbstractPrototypeModel): def training_step(self, train_batch, batch_idx): x, y = train_batch - x = x.view(x.size(0), -1) + x = x.view(x.size(0), -1) # flatten dis = self(x) plabels = self.proto_layer.component_labels mu = glvq_loss(dis, y, prototype_labels=plabels) loss = mu.sum(dim=0) - self.log("train_loss", loss) + + # Compute training accuracy with torch.no_grad(): preds = wtac(dis, plabels) - # self.train_acc.update(preds.int(), y.int()) - self.train_acc( - preds.int(), - y.int()) # FloatTensors are assumed to be class probabilities - self.log( - "acc", - self.train_acc, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss + # `.int()` because FloatTensors are assumed to be class probabilities + self.train_acc(preds.int(), y.int()) - # def training_epoch_end(self, outs): - # # Calling `self.train_acc.compute()` is - # # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)` - # self.log("train_acc_epoch", self.train_acc.compute()) + # Logging + self.log("train_loss", loss) + self.log("acc", + self.train_acc, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True) + + return loss def predict(self, x): # model.eval() # ?! @@ -76,8 +72,9 @@ class GLVQ(AbstractPrototypeModel): class ImageGLVQ(GLVQ): """GLVQ for training on image data. - GLVQ model that constrains the prototypes to the range [0, 1] by - clamping after updates. + GLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.proto_layer.components.data.clamp_(0.0, 1.0) @@ -89,6 +86,7 @@ class SiameseGLVQ(GLVQ): GLVQ model that applies an arbitrary transformation on the inputs and the prototypes before computing the distances between them. The weights in the transformation pipeline are only learned from the inputs. + """ def __init__(self, hparams, @@ -107,14 +105,18 @@ class SiameseGLVQ(GLVQ): def forward(self, x): self.sync_backbones() protos, _ = self.proto_layer() - latent_x = self.backbone(x) latent_protos = self.backbone_dependent(protos) - dis = euclidean_distance(latent_x, latent_protos) return dis def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ # model.eval() # ?! with torch.no_grad(): protos, plabels = self.proto_layer() @@ -122,3 +124,51 @@ class SiameseGLVQ(GLVQ): d = euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) return y_pred.numpy() + + +class GMLVQ(GLVQ): + """Generalized Matrix Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.omega_layer = torch.nn.Linear(self.hparams.input_dim, + self.latent_dim, + bias=False) + + def forward(self, x): + protos, _ = self.proto_layer() + latent_x = self.omega_layer(x) + latent_protos = self.omega_layer(protos) + dis = euclidean_distance(latent_x, latent_protos) + return dis + + +class LVQMLN(GLVQ): + """Learning Vector Quantization Multi-Layer Network. + + GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT + on the prototypes before computing the distances between them. This of + course, means that the prototypes no longer live the input space, but + rather in the embedding space. + + """ + def __init__(self, + hparams, + backbone_module=torch.nn.Identity, + backbone_params={}, + **kwargs): + super().__init__(hparams, **kwargs) + self.backbone = backbone_module(**backbone_params) + + def forward(self, x): + latent_protos, _ = self.proto_layer() + latent_x = self.backbone(x) + dis = euclidean_distance(latent_x, latent_protos) + return dis + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space.""" + with torch.no_grad(): + latent_protos, plabels = self.proto_layer() + d = euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() From 042b3fcaa2b796afcdb80a688dcd72434f366463 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 3 May 2021 13:19:23 +0200 Subject: [PATCH 07/27] Add tensorboard argument to visualization callbacks --- prototorch/models/callbacks/visualization.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 8692eff..98a1889 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -268,6 +268,7 @@ class Vis2DAbstract(pl.Callback): title="Prototype Visualization", cmap="viridis", border=1, + tensorboard=False, show_last_only=False, block=False): super().__init__() @@ -277,9 +278,17 @@ class Vis2DAbstract(pl.Callback): self.fig = plt.figure(self.title) self.cmap = cmap self.border = border + self.tensorboard = tensorboard self.show_last_only = show_last_only self.block = block + def add_to_tensorboard(self, trainer, pl_module): + tb = pl_module.logger.experiment + tb.add_figure(tag=f"{self.title}", + figure=self.fig, + global_step=trainer.current_epoch, + close=False) + class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): @@ -317,6 +326,8 @@ class VisGLVQ2D(Vis2DAbstract): ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: @@ -364,6 +375,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract): close=False, ) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: @@ -404,6 +417,8 @@ class VisNG2D(Vis2DAbstract): "k-", ) + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) if not self.block: plt.pause(0.01) else: From 96aeaa3448e1c7efba3a9fdd56e5d6a90f87a4a7 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 3 May 2021 13:20:49 +0200 Subject: [PATCH 08/27] Add support for multiple optimizers --- prototorch/models/abstract.py | 11 ++++++++++- prototorch/models/glvq.py | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 65ce8ae..dcc89a9 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -1,11 +1,20 @@ import pytorch_lightning as pl import torch +from torch.optim.lr_scheduler import ExponentialLR class AbstractLightningModel(pl.LightningModule): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) - return optimizer + scheduler = ExponentialLR(optimizer, + gamma=0.99, + last_epoch=-1, + verbose=False) + sch = { + "scheduler": scheduler, + "interval": "step", + } # called after each training step + return [optimizer], [sch] class AbstractPrototypeModel(AbstractLightningModel): diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 4e13cf2..33a349c 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -19,6 +19,7 @@ class GLVQ(AbstractPrototypeModel): # Default Values self.hparams.setdefault("distance", euclidean_distance) + self.hparams.setdefault("optimizer", torch.optim.Adam) self.proto_layer = LabeledComponents( labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), @@ -35,7 +36,7 @@ class GLVQ(AbstractPrototypeModel): dis = self.hparams.distance(x, protos) return dis - def training_step(self, train_batch, batch_idx): + def training_step(self, train_batch, batch_idx, optimizer_idx=None): x, y = train_batch x = x.view(x.size(0), -1) # flatten dis = self(x) @@ -102,6 +103,13 @@ class SiameseGLVQ(GLVQ): master_state = self.backbone.state_dict() self.backbone_dependent.load_state_dict(master_state, strict=True) + def configure_optimizers(self): + optim = self.hparams.optimizer + proto_opt = optim(self.proto_layer.parameters(), + lr=self.hparams.proto_lr) + bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) + return proto_opt, bb_opt + def forward(self, x): self.sync_backbones() protos, _ = self.proto_layer() From d8e017ae74e69d6a5c980f476b2e2a2cb5c3d54f Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 3 May 2021 16:09:22 +0200 Subject: [PATCH 09/27] Update SiameseGLVQ --- prototorch/models/glvq.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 33a349c..3a17f69 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -107,8 +107,13 @@ class SiameseGLVQ(GLVQ): optim = self.hparams.optimizer proto_opt = optim(self.proto_layer.parameters(), lr=self.hparams.proto_lr) - bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) - return proto_opt, bb_opt + if list(self.backbone.parameters()): + # only add an optimizer is the backbone has trainable parameters + # otherwise, the next line fails + bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) + return proto_opt, bb_opt + else: + return proto_opt def forward(self, x): self.sync_backbones() From a1ac5a70c7d19fe0e1eb739a6227a6929279c4b0 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 4 May 2021 14:34:00 +0200 Subject: [PATCH 10/27] Use squared euclidean distance in GMLVQ --- prototorch/models/glvq.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 3a17f69..aa4f5ae 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -3,7 +3,8 @@ import torch import torchmetrics from prototorch.components import LabeledComponents from prototorch.functions.competitions import wtac -from prototorch.functions.distances import euclidean_distance +from prototorch.functions.distances import (euclidean_distance, + squared_euclidean_distance) from prototorch.functions.losses import glvq_loss from prototorch.modules.prototypes import Prototypes1D @@ -151,7 +152,7 @@ class GMLVQ(GLVQ): protos, _ = self.proto_layer() latent_x = self.omega_layer(x) latent_protos = self.omega_layer(protos) - dis = euclidean_distance(latent_x, latent_protos) + dis = squared_euclidean_distance(latent_x, latent_protos) return dis From f402eea88439ae6dce8ce8e6717f6336cb922c73 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 4 May 2021 15:11:16 +0200 Subject: [PATCH 11/27] Add GMLVQ examples --- README.md | 5 +++-- examples/gmlvq_iris.py | 47 +++++++++++++++++++++++++++++++++++++++ examples/gmlvq_tecator.py | 47 +++++++++++++++++++++++++++++++++++++++ prototorch/models/glvq.py | 22 ++++++++++++++++-- 4 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 examples/gmlvq_iris.py create mode 100644 examples/gmlvq_tecator.py diff --git a/README.md b/README.md index 42450b3..e16667f 100644 --- a/README.md +++ b/README.md @@ -46,13 +46,13 @@ To assist in the development process, you may also find it useful to install - GLVQ - Siamese GLVQ - Neural Gas +- GMLVQ +- Limited-Rank GMLVQ ## Work in Progress - CBC - LVQMLN -- GMLVQ -- Limited-Rank GMLVQ ## Planned models @@ -62,3 +62,4 @@ To assist in the development process, you may also find it useful to install - PLVQ - SILVQ - KNN +- LVQ1 diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py new file mode 100644 index 0000000..7bee5fd --- /dev/null +++ b/examples/gmlvq_iris.py @@ -0,0 +1,47 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GMLVQ +from sklearn.datasets import load_iris +from torch.utils.data import DataLoader + +if __name__ == "__main__": + # Dataset + x_train, y_train = load_iris(return_X_y=True) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=1, + prototype_initializer=cinit.SMI(torch.Tensor(x_train), + torch.Tensor(y_train)), + input_dim=x_train.shape[1], + latent_dim=2, + lr=0.01, + ) + + # Initialize the model + model = GMLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x_train, y_train) + + # Namespace hook for the visualization to work + model.backbone = model.omega_layer + + # Setup trainer + trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/gmlvq_tecator.py b/examples/gmlvq_tecator.py new file mode 100644 index 0000000..8e0cf0d --- /dev/null +++ b/examples/gmlvq_tecator.py @@ -0,0 +1,47 @@ +"""GMLVQ example using the Tecator dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.tecator import Tecator +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GMLVQ +from torch.utils.data import DataLoader + +if __name__ == "__main__": + # Dataset + train_ds = Tecator(root="./datasets/", train=True) + + # Dataloaders + train_loader = DataLoader(train_ds, num_workers=0, batch_size=32) + + # Grab the full dataset to warm-start prototypes + x, y = next(iter(DataLoader(train_ds, batch_size=len(train_ds)))) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=2, + prototype_initializer=cinit.SMI(x, y), + input_dim=x.shape[1], + latent_dim=2, + lr=0.01, + ) + + # Initialize the model + model = GMLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x, y) + + # Namespace hook for the visualization to work + model.backbone = model.omega_layer + + # Setup trainer + trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index aa4f5ae..6f851ca 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -94,11 +94,13 @@ class SiameseGLVQ(GLVQ): hparams, backbone_module=torch.nn.Identity, backbone_params={}, + sync=True, **kwargs): super().__init__(hparams, **kwargs) self.backbone = backbone_module(**backbone_params) self.backbone_dependent = backbone_module( **backbone_params).requires_grad_(False) + self.sync = sync def sync_backbones(self): master_state = self.backbone.state_dict() @@ -117,7 +119,8 @@ class SiameseGLVQ(GLVQ): return proto_opt def forward(self, x): - self.sync_backbones() + if self.sync: + self.sync_backbones() protos, _ = self.proto_layer() latent_x = self.backbone(x) latent_protos = self.backbone_dependent(protos) @@ -145,7 +148,7 @@ class GMLVQ(GLVQ): def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) self.omega_layer = torch.nn.Linear(self.hparams.input_dim, - self.latent_dim, + self.hparams.latent_dim, bias=False) def forward(self, x): @@ -155,6 +158,21 @@ class GMLVQ(GLVQ): dis = squared_euclidean_distance(latent_x, latent_protos) return dis + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?! + with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = self.omega_layer(protos) + d = squared_euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + class LVQMLN(GLVQ): """Learning Vector Quantization Multi-Layer Network. From d644114090f56aae095ca38c72519178940c4284 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 4 May 2021 20:56:16 +0200 Subject: [PATCH 12/27] Add loss transfer function to glvq --- examples/glvq_spiral.py | 2 ++ prototorch/models/glvq.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index a2d3571..ccfa191 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -35,6 +35,8 @@ if __name__ == "__main__": prototype_initializer=cinit.SSI(torch.Tensor(x_train), torch.Tensor(y_train), noise=1e-7), + transfer_function="sigmoid_beta", + transfer_beta=10.0, lr=0.01, ) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 6f851ca..fa0a46a 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -2,6 +2,7 @@ import pytorch_lightning as pl import torch import torchmetrics from prototorch.components import LabeledComponents +from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac from prototorch.functions.distances import (euclidean_distance, squared_euclidean_distance) @@ -21,11 +22,14 @@ class GLVQ(AbstractPrototypeModel): # Default Values self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("optimizer", torch.optim.Adam) + self.hparams.setdefault("transfer_function", "identity") + self.hparams.setdefault("transfer_beta", 10.0) self.proto_layer = LabeledComponents( labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), initializer=self.hparams.prototype_initializer) + self.transfer_function = get_activation(self.hparams.transfer_function) self.train_acc = torchmetrics.Accuracy() @property @@ -43,7 +47,9 @@ class GLVQ(AbstractPrototypeModel): dis = self(x) plabels = self.proto_layer.component_labels mu = glvq_loss(dis, y, prototype_labels=plabels) - loss = mu.sum(dim=0) + batch_loss = self.transfer_function(mu, + beta=self.hparams.transfer_beta) + loss = batch_loss.sum(dim=0) # Compute training accuracy with torch.no_grad(): From 1c3613019b6a04f3613de94c5bef187e99d7e26f Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 6 May 2021 14:10:09 +0200 Subject: [PATCH 13/27] Update Examples to new initializer architecture. Visualization still borken for some examples. --- examples/cbc_circle.py | 37 ++------ examples/cbc_iris.py | 74 +++++++-------- examples/cbc_spiral.py | 4 +- examples/cbc_spiral_with_glvq_start.py | 7 +- examples/glvq_iris.py | 122 ++++--------------------- examples/glvq_iris_v1.py | 40 -------- examples/glvq_mnist.py | 5 +- examples/glvq_spiral.py | 3 +- examples/gmlvq_iris.py | 5 +- examples/gmlvq_tecator.py | 4 +- examples/ng_iris.py | 7 +- examples/siamese_glvq_iris.py | 10 +- prototorch/models/cbc.py | 17 ++-- prototorch/models/glvq.py | 3 - prototorch/models/neural_gas.py | 2 - 15 files changed, 92 insertions(+), 248 deletions(-) delete mode 100644 examples/glvq_iris_v1.py diff --git a/examples/cbc_circle.py b/examples/cbc_circle.py index 41702f7..1ed88a8 100644 --- a/examples/cbc_circle.py +++ b/examples/cbc_circle.py @@ -4,13 +4,12 @@ import numpy as np import pytorch_lightning as pl import torch from matplotlib import pyplot as plt +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset from sklearn.datasets import make_circles from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisPointProtos from prototorch.models.cbc import CBC, euclidean_similarity -from prototorch.models.glvq import GLVQ class VisualizationCallback(pl.Callback): @@ -32,7 +31,7 @@ class VisualizationCallback(pl.Callback): def on_epoch_end(self, trainer, pl_module): if self.prototype_model: - protos = pl_module.prototypes + protos = pl_module.components color = pl_module.prototype_labels else: protos = pl_module.components @@ -83,8 +82,8 @@ if __name__ == "__main__": hparams = dict( input_dim=x_train.shape[1], nclasses=len(np.unique(y_train)), - prototypes_per_class=5, - prototype_initializer="randn", + num_components=5, + component_initializer=cinit.RandomInitializer(x_train.shape[1]), lr=0.01, ) @@ -95,31 +94,15 @@ if __name__ == "__main__": similarity=euclidean_similarity, ) - model = GLVQ(hparams, data=[x_train, y_train]) - - # Fix the component locations - # model.proto_layer.requires_grad_(False) - - # import sys - # sys.exit() - - # Model summary - print(model) - # Callbacks - dvis = VisPointProtos( - data=(x_train, y_train), - save=True, - snap=False, - voronoi=True, - resolution=50, - pause_time=0.1, - make_gif=True, - ) + dvis = VisualizationCallback(x_train, + y_train, + prototype_model=False, + title="CBC Circle Example") # Setup trainer trainer = pl.Trainer( - max_epochs=10, + max_epochs=50, callbacks=[ dvis, ], diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index e5eab9e..5497e6c 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -4,30 +4,38 @@ import numpy as np import pytorch_lightning as pl import torch from matplotlib import pyplot as plt +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.cbc import CBC +from prototorch.models.cbc import CBC, euclidean_similarity class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): + def __init__( + self, + x_train, + y_train, + prototype_model=True, + title="Prototype Visualization", + cmap="viridis", + ): super().__init__() self.x_train = x_train self.y_train = y_train self.title = title self.fig = plt.figure(self.title) self.cmap = cmap + self.prototype_model = prototype_model def on_epoch_end(self, trainer, pl_module): - # protos = pl_module.prototypes - protos = pl_module.components - # plabels = pl_module.prototype_labels + if self.prototype_model: + protos = pl_module.components + color = pl_module.prototype_labels + else: + protos = pl_module.components + color = "k" ax = self.fig.gca() ax.cla() ax.set_title(self.title) @@ -37,8 +45,7 @@ class VisualizationCallback(pl.Callback): ax.scatter( protos[:, 0], protos[:, 1], - # c=plabels, - c="k", + c=color, cmap=self.cmap, edgecolor="k", marker="D", @@ -71,44 +78,33 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( input_dim=x_train.shape[1], - nclasses=3, - prototypes_per_class=3, - prototype_initializer="stratified_mean", + nclasses=len(np.unique(y_train)), + num_components=9, + component_initializer=cinit.StratifiedMeanInitializer( + torch.Tensor(x_train), torch.Tensor(y_train)), lr=0.01, ) # Initialize the model - model = CBC(hparams, data=[x_train, y_train]) - - # Fix the component locations - # model.proto_layer.requires_grad_(False) - - # Pure-positive reasonings - ncomps = 3 - nclasses = 3 - rmat = torch.stack( - [0.9 * torch.eye(ncomps), - torch.zeros(ncomps, nclasses)], dim=0) - # model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat}, - # strict=True) - - print(model.reasoning_layer.reasoning_probabilities) - # import sys - # sys.exit() - - # Model summary - print(model) + model = CBC( + hparams, + data=[x_train, y_train], + similarity=euclidean_similarity, + ) # Callbacks - vis = VisualizationCallback(x_train, y_train) + dvis = VisualizationCallback(x_train, + y_train, + prototype_model=False, + title="CBC Iris Example") # Setup trainer trainer = pl.Trainer( - max_epochs=100, + max_epochs=50, callbacks=[ - vis, + dvis, ], ) # Training loop - trainer.fit(model, train_loader) + trainer.fit(model, train_loader) \ No newline at end of file diff --git a/examples/cbc_spiral.py b/examples/cbc_spiral.py index 8147ba6..ec13138 100644 --- a/examples/cbc_spiral.py +++ b/examples/cbc_spiral.py @@ -4,9 +4,9 @@ import numpy as np import pytorch_lightning as pl import torch from matplotlib import pyplot as plt +from prototorch.datasets.abstract import NumpyDataset from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset from prototorch.models.cbc import CBC @@ -110,7 +110,7 @@ if __name__ == "__main__": # Pure-positive reasonings new_reasoning = torch.zeros_like( model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(model.proto_layer.prototype_labels): + for i, label in enumerate(model.component_layer.prototype_labels): new_reasoning[0][0][i][int(label)] = 1.0 model.reasoning_layer.reasoning_probabilities.data = new_reasoning diff --git a/examples/cbc_spiral_with_glvq_start.py b/examples/cbc_spiral_with_glvq_start.py index b6122e7..bfc62c4 100644 --- a/examples/cbc_spiral_with_glvq_start.py +++ b/examples/cbc_spiral_with_glvq_start.py @@ -8,9 +8,9 @@ import numpy as np import pytorch_lightning as pl import torch from matplotlib import pyplot as plt +from prototorch.datasets.abstract import NumpyDataset from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset from prototorch.models.cbc import CBC from prototorch.models.glvq import GLVQ @@ -132,11 +132,12 @@ if __name__ == "__main__": train(glvq_model, x_train, y_train, train_loader, epochs=10) # Transfer Prototypes - cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict()) + cbc_model.component_layer.load_state_dict( + glvq_model.proto_layer.state_dict()) # Pure-positive reasonings new_reasoning = torch.zeros_like( cbc_model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(cbc_model.proto_layer.prototype_labels): + for i, label in enumerate(cbc_model.component_layer.prototype_labels): new_reasoning[0][0][i][int(label)] = 1.0 new_reasoning[1][0][i][1 - int(label)] = 1.0 diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 5ada61d..3e698f7 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -1,86 +1,16 @@ """GLVQ example using the Iris dataset.""" -import argparse - -import numpy as np import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisGLVQ2D from prototorch.models.glvq import GLVQ - -class GLVQIris(GLVQ): - @staticmethod - def add_model_specific_args(parent_parser): - parser = argparse.ArgumentParser(parents=[parent_parser], - add_help=False) - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--lr", type=float, default=1e-1) - parser.add_argument("--batch_size", type=int, default=150) - parser.add_argument("--input_dim", type=int, default=2) - parser.add_argument("--nclasses", type=int, default=3) - parser.add_argument("--prototypes_per_class", type=int, default=3) - parser.add_argument("--prototype_initializer", - type=str, - default="stratified_mean") - return parser - - -class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - if __name__ == "__main__": - # For best-practices when using `argparse` with `pytorch_lightning`, see - # https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html - parser = argparse.ArgumentParser() - # Dataset x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] @@ -89,43 +19,23 @@ if __name__ == "__main__": # Dataloaders train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - # Add model specific args - parser = GLVQIris.add_model_specific_args(parser) - - # Callbacks - vis = VisualizationCallback(x_train, y_train) - - # Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc. - parser = pl.Trainer.add_argparse_args(parser) - - # Setup trainer - trainer = pl.Trainer.from_argparse_args( - parser, - max_epochs=10, - callbacks=[ - vis, - ], # comment this line out to disable the visualization + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=2, + prototype_initializer=cinit.StratifiedMeanInitializer( + torch.Tensor(x_train), torch.Tensor(y_train)), + lr=0.01, ) - # trainer.tune(model) # Initialize the model - args = parser.parse_args() - model = GLVQIris(args, data=[x_train, y_train]) + model = GLVQ(hparams) - # Model summary - print(model) + # Setup trainer + trainer = pl.Trainer( + max_epochs=50, + callbacks=[VisGLVQ2D(x_train, y_train)], + ) # Training loop trainer.fit(model, train_loader) - - # Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate) - ckpt = "glvq_iris.ckpt" - trainer.save_checkpoint(ckpt) - - # Load the checkpoint - new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt) - - print(new_model) - - # Continue training - trainer.fit(new_model, train_loader) # TODO See why this fails! diff --git a/examples/glvq_iris_v1.py b/examples/glvq_iris_v1.py deleted file mode 100644 index 893e177..0000000 --- a/examples/glvq_iris_v1.py +++ /dev/null @@ -1,40 +0,0 @@ -"""GLVQ example using the Iris dataset.""" - -import pytorch_lightning as pl -import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisGLVQ2D -from prototorch.models.glvq import GLVQ -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -if __name__ == "__main__": - # Dataset - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - nclasses=3, - prototypes_per_class=2, - prototype_initializer=cinit.StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), - lr=0.01, - ) - - # Initialize the model - model = GLVQ(hparams) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=50, - callbacks=[VisGLVQ2D(x_train, y_train)], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/glvq_mnist.py b/examples/glvq_mnist.py index 6ab5091..fd96d3f 100644 --- a/examples/glvq_mnist.py +++ b/examples/glvq_mnist.py @@ -7,6 +7,7 @@ import argparse import pytorch_lightning as pl import torchvision +from prototorch.components import initializers as cinit from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST @@ -92,12 +93,12 @@ if __name__ == "__main__": input_dim=28 * 28, nclasses=10, prototypes_per_class=1, - prototype_initializer="stratified_mean", + prototype_initializer=cinit.StratifiedMeanInitializer(x, y), lr=args.lr, ) # Initialize the model - model = ImageGLVQ(hparams, data=[x, y]) + model = ImageGLVQ(hparams) # Model summary print(model) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index ccfa191..824bd0d 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -5,9 +5,10 @@ import torch from prototorch.components import initializers as cinit from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.spiral import make_spiral +from torch.utils.data import DataLoader + from prototorch.models.callbacks.visualization import VisGLVQ2D from prototorch.models.glvq import GLVQ -from torch.utils.data import DataLoader class StopOnNaN(pl.Callback): diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index 7bee5fd..ce15e56 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -4,11 +4,12 @@ import pytorch_lightning as pl import torch from prototorch.components import initializers as cinit from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import GMLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GMLVQ + if __name__ == "__main__": # Dataset x_train, y_train = load_iris(return_X_y=True) diff --git a/examples/gmlvq_tecator.py b/examples/gmlvq_tecator.py index 8e0cf0d..d5a92e9 100644 --- a/examples/gmlvq_tecator.py +++ b/examples/gmlvq_tecator.py @@ -1,12 +1,12 @@ """GMLVQ example using the Tecator dataset.""" import pytorch_lightning as pl -import torch from prototorch.components import initializers as cinit from prototorch.datasets.tecator import Tecator +from torch.utils.data import DataLoader + from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D from prototorch.models.glvq import GMLVQ -from torch.utils.data import DataLoader if __name__ == "__main__": # Dataset diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 16e954a..8b3f51f 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -1,15 +1,14 @@ """Neural Gas example using the Iris dataset.""" -import numpy as np import pytorch_lightning as pl -from matplotlib import pyplot as plt from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisNG2D -from prototorch.models.neural_gas import NeuralGas from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler from torch.utils.data import DataLoader +from prototorch.models.callbacks.visualization import VisNG2D +from prototorch.models.neural_gas import NeuralGas + if __name__ == "__main__": # Dataset x_train, y_train = load_iris(return_X_y=True) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 897e3f0..8a4530f 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -2,14 +2,16 @@ import pytorch_lightning as pl import torch -from prototorch.components import (StratifiedMeanInitializer, - StratifiedSelectionInitializer) +from prototorch.components import ( + StratifiedMeanInitializer +) from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import SiameseGLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import SiameseGLVQ + class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 5c20260..eff85df 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -1,10 +1,9 @@ import pytorch_lightning as pl import torch import torchmetrics - +from prototorch.components.components import Components from prototorch.functions.distances import euclidean_distance from prototorch.functions.similarities import cosine_similarity -from prototorch.modules.prototypes import Prototypes1D def rescaled_cosine_similarity(x, y): @@ -93,12 +92,8 @@ class CBC(pl.LightningModule): super().__init__() self.save_hyperparameters(hparams) self.margin = margin - self.proto_layer = Prototypes1D( - input_dim=self.hparams.input_dim, - nclasses=self.hparams.nclasses, - prototypes_per_class=self.hparams.prototypes_per_class, - prototype_initializer=self.hparams.prototype_initializer, - **kwargs) + self.component_layer = Components(self.hparams.num_components, + self.hparams.component_initializer) # self.similarity = CosineSimilarity() self.similarity = similarity self.backbone = backbone_class() @@ -110,7 +105,7 @@ class CBC(pl.LightningModule): @property def components(self): - return self.proto_layer.prototypes.detach().cpu() + return self.component_layer.components.detach().cpu() @property def reasonings(self): @@ -126,7 +121,7 @@ class CBC(pl.LightningModule): def forward(self, x): self.sync_backbones() - protos, _ = self.proto_layer() + protos = self.component_layer() latent_x = self.backbone(x) latent_protos = self.backbone_dependent(protos) @@ -167,4 +162,4 @@ class ImageCBC(CBC): """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) - self.proto_layer.prototypes.data.clamp_(0.0, 1.0) + self.component_layer.prototypes.data.clamp_(0.0, 1.0) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index fa0a46a..efd76b7 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -1,4 +1,3 @@ -import pytorch_lightning as pl import torch import torchmetrics from prototorch.components import LabeledComponents @@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac from prototorch.functions.distances import (euclidean_distance, squared_euclidean_distance) from prototorch.functions.losses import glvq_loss -from prototorch.modules.prototypes import Prototypes1D from .abstract import AbstractPrototypeModel @@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel): with torch.no_grad(): preds = wtac(dis, plabels) # `.int()` because FloatTensors are assumed to be class probabilities - self.train_acc(preds.int(), y.int()) # Logging self.log("train_loss", loss) diff --git a/prototorch/models/neural_gas.py b/prototorch/models/neural_gas.py index d98bfa0..bebd289 100644 --- a/prototorch/models/neural_gas.py +++ b/prototorch/models/neural_gas.py @@ -1,9 +1,7 @@ -import pytorch_lightning as pl import torch from prototorch.components import Components from prototorch.components import initializers as cinit from prototorch.functions.distances import euclidean_distance -from prototorch.modules import Prototypes1D from prototorch.modules.losses import NeuralGasEnergy from .abstract import AbstractPrototypeModel From 5a2f4f617003ea511c3f39c8ffad74b28115933c Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 6 May 2021 18:02:01 +0200 Subject: [PATCH 14/27] Revert deletion of training accuracy. --- prototorch/models/glvq.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index efd76b7..ff1da61 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -52,6 +52,8 @@ class GLVQ(AbstractPrototypeModel): # Compute training accuracy with torch.no_grad(): preds = wtac(dis, plabels) + + self.train_acc(preds.int(), y.int()) # `.int()` because FloatTensors are assumed to be class probabilities # Logging From 3df282a0af4e911c6afea54a233f537f179499b8 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 6 May 2021 18:41:33 +0200 Subject: [PATCH 15/27] Increase visualization pause. --- prototorch/models/callbacks/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 98a1889..a7486a8 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -378,7 +378,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract): if self.tensorboard: self.add_to_tensorboard(trainer, pl_module) if not self.block: - plt.pause(0.01) + plt.pause(0.05) else: plt.show(block=True) From 79e5eaa69ab501bcac15293c314b249b3eb0c538 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 6 May 2021 18:41:50 +0200 Subject: [PATCH 16/27] Rename GMLVQ Example. --- examples/{gmlvq_tecator.py => liranmlvq_tecator.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename examples/{gmlvq_tecator.py => liranmlvq_tecator.py} (95%) diff --git a/examples/gmlvq_tecator.py b/examples/liranmlvq_tecator.py similarity index 95% rename from examples/gmlvq_tecator.py rename to examples/liranmlvq_tecator.py index d5a92e9..d9fff1e 100644 --- a/examples/gmlvq_tecator.py +++ b/examples/liranmlvq_tecator.py @@ -1,4 +1,4 @@ -"""GMLVQ example using the Tecator dataset.""" +"""Limited Rank MLVQ example using the Tecator dataset.""" import pytorch_lightning as pl from prototorch.components import initializers as cinit From 4bbe73e3a9f7e63eb4adc76c32552b25f822bbaf Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 6 May 2021 18:42:06 +0200 Subject: [PATCH 17/27] Add GRLVQ with examples. --- examples/grlvq_iris.py | 62 +++++++++++++++++++++++++++++++++++++++ examples/grlvq_spiral.py | 57 +++++++++++++++++++++++++++++++++++ prototorch/models/glvq.py | 39 ++++++++++++++++++++++-- 3 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 examples/grlvq_iris.py create mode 100644 examples/grlvq_spiral.py diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py new file mode 100644 index 0000000..01c31cb --- /dev/null +++ b/examples/grlvq_iris.py @@ -0,0 +1,62 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from sklearn.datasets import load_iris +from torch.utils.data import DataLoader + +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GRLVQ + +from sklearn.preprocessing import StandardScaler + + +class PrintRelevanceCallback(pl.Callback): + def on_epoch_end(self, trainer, pl_module: GRLVQ): + print(pl_module.relevance_profile) + + +if __name__ == "__main__": + # Dataset + x_train, y_train = load_iris(return_X_y=True) + x_train = x_train[:, [0, 2]] + scaler = StandardScaler() + scaler.fit(x_train) + x_train = scaler.transform(x_train) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, + num_workers=0, + batch_size=50, + shuffle=True) + + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=1, + #prototype_initializer=cinit.SMI(torch.Tensor(x_train), + # torch.Tensor(y_train)), + prototype_initializer=cinit.UniformInitializer(2), + input_dim=x_train.shape[1], + lr=0.1, + #transfer_function="sigmoid_beta", + ) + + # Initialize the model + model = GRLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x_train, y_train) + debug = PrintRelevanceCallback() + + # Setup trainer + trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/grlvq_spiral.py b/examples/grlvq_spiral.py new file mode 100644 index 0000000..61d754c --- /dev/null +++ b/examples/grlvq_spiral.py @@ -0,0 +1,57 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from sklearn.datasets import load_iris +from torch.utils.data import DataLoader + +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GRLVQ + +from sklearn.preprocessing import StandardScaler + +from prototorch.datasets.spiral import make_spiral + + +class PrintRelevanceCallback(pl.Callback): + def on_epoch_end(self, trainer, pl_module: GRLVQ): + print(pl_module.relevance_profile) + + +if __name__ == "__main__": + # Dataset + x_train, y_train = make_spiral(n_samples=1000, noise=0.3) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=20, + prototype_initializer=cinit.SSI(torch.Tensor(x_train), + torch.Tensor(y_train)), + #prototype_initializer=cinit.UniformInitializer(2), + input_dim=x_train.shape[1], + lr=0.1, + #transfer_function="sigmoid_beta", + ) + + # Initialize the model + model = GRLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x_train, y_train) + debug = PrintRelevanceCallback() + + # Setup trainer + trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index ff1da61..868d4c9 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -3,7 +3,7 @@ import torchmetrics from prototorch.components import LabeledComponents from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac -from prototorch.functions.distances import (euclidean_distance, +from prototorch.functions.distances import (euclidean_distance, omega_distance, squared_euclidean_distance) from prototorch.functions.losses import glvq_loss @@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel): @property def prototype_labels(self): - return self.proto_layer.component_labels.detach().numpy() + return self.proto_layer.component_labels.detach().cpu() def forward(self, x): protos, _ = self.proto_layer() @@ -148,6 +148,41 @@ class SiameseGLVQ(GLVQ): return y_pred.numpy() +class GRLVQ(GLVQ): + """Generalized Relevance Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.relevances = torch.nn.parameter.Parameter( + torch.ones(self.hparams.input_dim)) + + def forward(self, x): + protos, _ = self.proto_layer() + dis = omega_distance(x, protos, torch.diag(self.relevances)) + return dis + + def backbone(self, x): + return x @ torch.diag(self.relevances) + + @property + def relevance_profile(self): + return self.relevances.detach().cpu() + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?! + with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = protos @ torch.diag(self.relevances) + d = squared_euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + + class GMLVQ(GLVQ): """Generalized Matrix Learning Vector Quantization.""" def __init__(self, hparams, **kwargs): From 1b9bcf21f6ad478b7fec0d8abfb502688c20f184 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 6 May 2021 18:50:37 +0200 Subject: [PATCH 18/27] Fix typo --- examples/{liranmlvq_tecator.py => liramlvq_tecator.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{liranmlvq_tecator.py => liramlvq_tecator.py} (100%) diff --git a/examples/liranmlvq_tecator.py b/examples/liramlvq_tecator.py similarity index 100% rename from examples/liranmlvq_tecator.py rename to examples/liramlvq_tecator.py From e87663d10cb4fd0cb84dd7b2e78b516fa8f08d87 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 13:07:30 +0200 Subject: [PATCH 19/27] Make siamese example script reproducible --- examples/siamese_glvq_iris.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 8a4530f..d117f4f 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -2,18 +2,16 @@ import pytorch_lightning as pl import torch -from prototorch.components import ( - StratifiedMeanInitializer -) +from prototorch.components import initializers as cinit from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import SiameseGLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import SiameseGLVQ - class Backbone(torch.nn.Module): + """Two fully connected layers with ReLU activation.""" def __init__(self, input_size=4, hidden_size=10, latent_size=2): super().__init__() self.input_size = input_size @@ -24,7 +22,9 @@ class Backbone(torch.nn.Module): self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.dense2(self.relu(self.dense1(x)))) + x = self.relu(self.dense1(x)) + out = self.relu(self.dense2(x)) + return out if __name__ == "__main__": @@ -32,16 +32,20 @@ if __name__ == "__main__": x_train, y_train = load_iris(return_X_y=True) train_ds = NumpyDataset(x_train, y_train) + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + # Dataloaders train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) # Hyperparameters hparams = dict( nclasses=3, - prototypes_per_class=1, - prototype_initializer=StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), - lr=0.01, + prototypes_per_class=2, + prototype_initializer=cinit.SMI(torch.Tensor(x_train), + torch.Tensor(y_train)), + proto_lr=0.001, + bb_lr=0.001, ) # Initialize the model @@ -54,7 +58,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) + vis = VisSiameseGLVQ2D(x_train, y_train, border=0.1) # Setup trainer trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) From f2541acde9b99ff0016a920942103b8358a67026 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:21:35 +0200 Subject: [PATCH 20/27] Unclutter the examples folder --- examples/cbc_circle.py | 112 ------------------- examples/cbc_mnist.py | 128 --------------------- examples/cbc_spiral.py | 135 ----------------------- examples/cbc_spiral_with_glvq_start.py | 147 ------------------------- examples/glvq_mnist.py | 119 -------------------- examples/grlvq_iris.py | 62 ----------- examples/grlvq_spiral.py | 57 ---------- 7 files changed, 760 deletions(-) delete mode 100644 examples/cbc_circle.py delete mode 100644 examples/cbc_mnist.py delete mode 100644 examples/cbc_spiral.py delete mode 100644 examples/cbc_spiral_with_glvq_start.py delete mode 100644 examples/glvq_mnist.py delete mode 100644 examples/grlvq_iris.py delete mode 100644 examples/grlvq_spiral.py diff --git a/examples/cbc_circle.py b/examples/cbc_circle.py deleted file mode 100644 index 1ed88a8..0000000 --- a/examples/cbc_circle.py +++ /dev/null @@ -1,112 +0,0 @@ -"""CBC example using the Iris dataset.""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import make_circles -from torch.utils.data import DataLoader - -from prototorch.models.cbc import CBC, euclidean_similarity - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.components - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_circles(n_samples=300, - shuffle=True, - noise=0.05, - random_state=None, - factor=0.5) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=len(np.unique(y_train)), - num_components=5, - component_initializer=cinit.RandomInitializer(x_train.shape[1]), - lr=0.01, - ) - - # Initialize the model - model = CBC( - hparams, - data=[x_train, y_train], - similarity=euclidean_similarity, - ) - - # Callbacks - dvis = VisualizationCallback(x_train, - y_train, - prototype_model=False, - title="CBC Circle Example") - - # Setup trainer - trainer = pl.Trainer( - max_epochs=50, - callbacks=[ - dvis, - ], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/cbc_mnist.py b/examples/cbc_mnist.py deleted file mode 100644 index 8546025..0000000 --- a/examples/cbc_mnist.py +++ /dev/null @@ -1,128 +0,0 @@ -"""CBC example using the MNIST dataset. - -This script also shows how to use Tensorboard for visualizing the prototypes. -""" - -import argparse - -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader -from torchvision import transforms -from torchvision.datasets import MNIST - -from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity - - -class VisualizationCallback(pl.Callback): - def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): - super().__init__() - self.to_shape = to_shape - self.nrow = nrow - - def on_epoch_end(self, trainer, pl_module: ImageCBC): - tb = pl_module.logger.experiment - - # components - components = pl_module.components - components_img = components.reshape(self.to_shape) - grid = torchvision.utils.make_grid(components_img, nrow=self.nrow) - tb.add_image( - tag="MNIST Components", - img_tensor=grid, - global_step=trainer.current_epoch, - dataformats="CHW", - ) - # Reasonings - reasonings = pl_module.reasonings - tb.add_images( - tag="MNIST Reasoning", - img_tensor=reasonings, - global_step=trainer.current_epoch, - dataformats="NCHW", - ) - - -if __name__ == "__main__": - # Arguments - parser = argparse.ArgumentParser() - parser.add_argument("--epochs", - type=int, - default=10, - help="Epochs to train.") - parser.add_argument("--lr", - type=float, - default=0.001, - help="Learning rate.") - parser.add_argument("--batch_size", - type=int, - default=256, - help="Batch size.") - parser.add_argument("--gpus", - type=int, - default=0, - help="Number of GPUs to use.") - parser.add_argument("--ppc", - type=int, - default=1, - help="Prototypes-Per-Class.") - args = parser.parse_args() - - # Dataset - mnist_train = MNIST( - "./datasets", - train=True, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - mnist_test = MNIST( - "./datasets", - train=False, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - - # Dataloaders - train_loader = DataLoader(mnist_train, batch_size=32) - test_loader = DataLoader(mnist_test, batch_size=32) - - # Grab the full dataset to warm-start prototypes - x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train)))) - x = x.view(len(mnist_train), -1) - - # Hyperparameters - hparams = dict( - input_dim=28 * 28, - nclasses=10, - prototypes_per_class=args.ppc, - prototype_initializer="randn", - lr=0.01, - similarity=euclidean_similarity, - ) - - # Initialize the model - model = CBC(hparams, data=[x, y]) - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc) - - # Setup trainer - trainer = pl.Trainer( - gpus=args.gpus, # change to use GPUs for training - max_epochs=args.epochs, - callbacks=[vis], - track_grad_norm=2, - # accelerator="ddp_cpu", # DEBUG-ONLY - # num_processes=2, # DEBUG-ONLY - ) - - # Training loop - trainer.fit(model, train_loader, test_loader) diff --git a/examples/cbc_spiral.py b/examples/cbc_spiral.py deleted file mode 100644 index ec13138..0000000 --- a/examples/cbc_spiral.py +++ /dev/null @@ -1,135 +0,0 @@ -"""CBC example using the Iris dataset.""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from prototorch.datasets.abstract import NumpyDataset -from torch.utils.data import DataLoader - -from prototorch.models.cbc import CBC - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.prototypes - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -def make_spirals(n_samples=500, noise=0.3): - def get_samples(n, delta_t): - points = [] - for i in range(n): - r = i / n_samples * 5 - t = 1.75 * i / n * 2 * np.pi + delta_t - x = r * np.sin(t) + np.random.rand(1) * noise - y = r * np.cos(t) + np.random.rand(1) * noise - points.append([x, y]) - return points - - n = n_samples // 2 - positive = get_samples(n=n, delta_t=0) - negative = get_samples(n=n, delta_t=np.pi) - x = np.concatenate( - [np.array(positive).reshape(n, -1), - np.array(negative).reshape(n, -1)], - axis=0) - y = np.concatenate([np.zeros(n), np.ones(n)]) - return x, y - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_spirals(n_samples=1000, noise=0.3) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=2, - prototypes_per_class=40, - prototype_initializer="stratified_random", - lr=0.05, - ) - - # Initialize the model - model_class = CBC - model = model_class(hparams, data=[x_train, y_train]) - - # Pure-positive reasonings - new_reasoning = torch.zeros_like( - model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(model.component_layer.prototype_labels): - new_reasoning[0][0][i][int(label)] = 1.0 - - model.reasoning_layer.reasoning_probabilities.data = new_reasoning - - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(x_train, - y_train, - prototype_model=hasattr(model, "prototypes")) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=500, - callbacks=[ - vis, - ], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/cbc_spiral_with_glvq_start.py b/examples/cbc_spiral_with_glvq_start.py deleted file mode 100644 index bfc62c4..0000000 --- a/examples/cbc_spiral_with_glvq_start.py +++ /dev/null @@ -1,147 +0,0 @@ -"""CBC example using the spirals dataset. - -This example shows how to jump start a model by transferring weights from -another more stable model. -""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from prototorch.datasets.abstract import NumpyDataset -from torch.utils.data import DataLoader - -from prototorch.models.cbc import CBC -from prototorch.models.glvq import GLVQ - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.prototypes - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -def make_spirals(n_samples=500, noise=0.3): - def get_samples(n, delta_t): - points = [] - for i in range(n): - r = i / n_samples * 5 - t = 1.75 * i / n * 2 * np.pi + delta_t - x = r * np.sin(t) + np.random.rand(1) * noise - y = r * np.cos(t) + np.random.rand(1) * noise - points.append([x, y]) - return points - - n = n_samples // 2 - positive = get_samples(n=n, delta_t=0) - negative = get_samples(n=n, delta_t=np.pi) - x = np.concatenate( - [np.array(positive).reshape(n, -1), - np.array(negative).reshape(n, -1)], - axis=0) - y = np.concatenate([np.zeros(n), np.ones(n)]) - return x, y - - -def train(model, x_train, y_train, train_loader, epochs=100): - # Callbacks - vis = VisualizationCallback(x_train, - y_train, - prototype_model=hasattr(model, "prototypes")) - # Setup trainer - trainer = pl.Trainer( - max_epochs=epochs, - callbacks=[ - vis, - ], - ) - # Training loop - trainer.fit(model, train_loader) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_spirals(n_samples=1000, noise=0.3) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=2, - prototypes_per_class=40, - prototype_initializer="stratified_random", - lr=0.05, - ) - - # Initialize the model - glvq_model = GLVQ(hparams, data=[x_train, y_train]) - cbc_model = CBC(hparams, data=[x_train, y_train]) - - # Train GLVQ - train(glvq_model, x_train, y_train, train_loader, epochs=10) - - # Transfer Prototypes - cbc_model.component_layer.load_state_dict( - glvq_model.proto_layer.state_dict()) - # Pure-positive reasonings - new_reasoning = torch.zeros_like( - cbc_model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(cbc_model.component_layer.prototype_labels): - new_reasoning[0][0][i][int(label)] = 1.0 - new_reasoning[1][0][i][1 - int(label)] = 1.0 - - cbc_model.reasoning_layer.reasoning_probabilities.data = new_reasoning - - # Train CBC - train(cbc_model, x_train, y_train, train_loader, epochs=50) diff --git a/examples/glvq_mnist.py b/examples/glvq_mnist.py deleted file mode 100644 index fd96d3f..0000000 --- a/examples/glvq_mnist.py +++ /dev/null @@ -1,119 +0,0 @@ -"""GLVQ example using the MNIST dataset. - -This script also shows how to use Tensorboard for visualizing the prototypes. -""" - -import argparse - -import pytorch_lightning as pl -import torchvision -from prototorch.components import initializers as cinit -from torch.utils.data import DataLoader -from torchvision import transforms -from torchvision.datasets import MNIST - -from prototorch.models.glvq import ImageGLVQ - - -class VisualizationCallback(pl.Callback): - def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): - super().__init__() - self.to_shape = to_shape - self.nrow = nrow - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.proto_layer.prototypes.detach().cpu() - protos_img = protos.reshape(self.to_shape) - grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow) - # grid = grid.permute((1, 2, 0)) - tb = pl_module.logger.experiment - tb.add_image( - tag="MNIST Prototypes", - img_tensor=grid, - global_step=trainer.current_epoch, - dataformats="CHW", - ) - - -if __name__ == "__main__": - # Arguments - parser = argparse.ArgumentParser() - parser.add_argument("--epochs", - type=int, - default=10, - help="Epochs to train.") - parser.add_argument("--lr", - type=float, - default=0.001, - help="Learning rate.") - parser.add_argument("--batch_size", - type=int, - default=256, - help="Batch size.") - parser.add_argument("--gpus", - type=int, - default=0, - help="Number of GPUs to use.") - parser.add_argument("--ppc", - type=int, - default=1, - help="Prototypes-Per-Class.") - args = parser.parse_args() - - # Dataset - mnist_train = MNIST( - "./datasets", - train=True, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - mnist_test = MNIST( - "./datasets", - train=False, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - - # Dataloaders - train_loader = DataLoader(mnist_train, batch_size=1024) - test_loader = DataLoader(mnist_test, batch_size=1024) - - # Grab the full dataset to warm-start prototypes - x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train)))) - x = x.view(len(mnist_train), -1) - - # Hyperparameters - hparams = dict( - input_dim=28 * 28, - nclasses=10, - prototypes_per_class=1, - prototype_initializer=cinit.StratifiedMeanInitializer(x, y), - lr=args.lr, - ) - - # Initialize the model - model = ImageGLVQ(hparams) - - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc) - - # Setup trainer - trainer = pl.Trainer( - gpus=args.gpus, # change to use GPUs for training - max_epochs=args.epochs, - callbacks=[vis], - # accelerator="ddp_cpu", # DEBUG-ONLY - # num_processes=2, # DEBUG-ONLY - ) - - # Training loop - trainer.fit(model, train_loader, test_loader) diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py deleted file mode 100644 index 01c31cb..0000000 --- a/examples/grlvq_iris.py +++ /dev/null @@ -1,62 +0,0 @@ -"""GMLVQ example using all four dimensions of the Iris dataset.""" - -import pytorch_lightning as pl -import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import GRLVQ - -from sklearn.preprocessing import StandardScaler - - -class PrintRelevanceCallback(pl.Callback): - def on_epoch_end(self, trainer, pl_module: GRLVQ): - print(pl_module.relevance_profile) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - scaler = StandardScaler() - scaler.fit(x_train) - x_train = scaler.transform(x_train) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, - num_workers=0, - batch_size=50, - shuffle=True) - - # Hyperparameters - hparams = dict( - nclasses=3, - prototypes_per_class=1, - #prototype_initializer=cinit.SMI(torch.Tensor(x_train), - # torch.Tensor(y_train)), - prototype_initializer=cinit.UniformInitializer(2), - input_dim=x_train.shape[1], - lr=0.1, - #transfer_function="sigmoid_beta", - ) - - # Initialize the model - model = GRLVQ(hparams) - - # Model summary - print(model) - - # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) - debug = PrintRelevanceCallback() - - # Setup trainer - trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/grlvq_spiral.py b/examples/grlvq_spiral.py deleted file mode 100644 index 61d754c..0000000 --- a/examples/grlvq_spiral.py +++ /dev/null @@ -1,57 +0,0 @@ -"""GMLVQ example using all four dimensions of the Iris dataset.""" - -import pytorch_lightning as pl -import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import GRLVQ - -from sklearn.preprocessing import StandardScaler - -from prototorch.datasets.spiral import make_spiral - - -class PrintRelevanceCallback(pl.Callback): - def on_epoch_end(self, trainer, pl_module: GRLVQ): - print(pl_module.relevance_profile) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_spiral(n_samples=1000, noise=0.3) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - nclasses=2, - prototypes_per_class=20, - prototype_initializer=cinit.SSI(torch.Tensor(x_train), - torch.Tensor(y_train)), - #prototype_initializer=cinit.UniformInitializer(2), - input_dim=x_train.shape[1], - lr=0.1, - #transfer_function="sigmoid_beta", - ) - - # Initialize the model - model = GRLVQ(hparams) - - # Model summary - print(model) - - # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) - debug = PrintRelevanceCallback() - - # Setup trainer - trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) - - # Training loop - trainer.fit(model, train_loader) From 5f937066bf47150070af73ea89b47b221fe35edc Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:22:54 +0200 Subject: [PATCH 21/27] Move and improve visualization callbacks --- .../{callbacks/visualization.py => vis.py} | 150 ++++++++++-------- 1 file changed, 87 insertions(+), 63 deletions(-) rename prototorch/models/{callbacks/visualization.py => vis.py} (82%) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/vis.py similarity index 82% rename from prototorch/models/callbacks/visualization.py rename to prototorch/models/vis.py index a7486a8..6099bc3 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/vis.py @@ -9,6 +9,7 @@ from prototorch.utils.celluloid import Camera from prototorch.utils.colors import color_scheme from prototorch.utils.utils import (gif_from_dir, make_directory, prettify_string) +from torch.utils.data import DataLoader, Dataset class VisWeights(pl.Callback): @@ -263,25 +264,54 @@ class VisPointProtos(VisWeights): class Vis2DAbstract(pl.Callback): def __init__(self, - x_train, - y_train, + data, title="Prototype Visualization", cmap="viridis", border=1, + resolution=50, tensorboard=False, show_last_only=False, + pause_time=0.1, block=False): super().__init__() - self.x_train = x_train - self.y_train = y_train + + if isinstance(data, Dataset): + x, y = next(iter(DataLoader(data, batch_size=len(data)))) + x = x.view(len(data), -1) # flatten + else: + x, y = data + self.x_train = x + self.y_train = y + self.title = title self.fig = plt.figure(self.title) self.cmap = cmap self.border = border + self.resolution = resolution self.tensorboard = tensorboard self.show_last_only = show_last_only + self.pause_time = pause_time self.block = block + def setup_ax(self, xlabel=None, ylabel=None): + ax = self.fig.gca() + ax.cla() + ax.set_title(self.title) + ax.axis("off") + if xlabel: + ax.set_xlabel("Data dimension 1") + if ylabel: + ax.set_ylabel("Data dimension 2") + return ax + + def get_mesh_input(self, x): + x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border + y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border + xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution), + np.arange(y_min, y_max, 1 / self.resolution)) + mesh_input = np.c_[xx.ravel(), yy.ravel()] + return mesh_input, xx, yy + def add_to_tensorboard(self, trainer, pl_module): tb = pl_module.logger.experiment tb.add_figure(tag=f"{self.title}", @@ -289,6 +319,14 @@ class Vis2DAbstract(pl.Callback): global_step=trainer.current_epoch, close=False) + def log_and_display(self, trainer, pl_module): + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) + if not self.block: + plt.pause(self.pause_time) + else: + plt.show(block=True) + class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): @@ -298,12 +336,8 @@ class VisGLVQ2D(Vis2DAbstract): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.axis("off") - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( protos[:, 0], @@ -315,23 +349,15 @@ class VisGLVQ2D(Vis2DAbstract): s=50, ) x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] + mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - if self.tensorboard: - self.add_to_tensorboard(trainer, pl_module) - if not self.block: - plt.pause(0.01) - else: - plt.show(block=True) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + self.log_and_display(trainer, pl_module) class VisSiameseGLVQ2D(Vis2DAbstract): @@ -341,10 +367,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract): x_train, y_train = self.x_train, self.y_train x_train = pl_module.backbone(torch.Tensor(x_train)).detach() protos = pl_module.backbone(torch.Tensor(protos)).detach() - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.axis("off") + ax = self.setup_ax() ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( protos[:, 0], @@ -356,48 +379,54 @@ class VisSiameseGLVQ2D(Vis2DAbstract): s=50, ) x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border - y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] + mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - tb = pl_module.logger.experiment - tb.add_figure( - tag=f"{self.title}", - figure=self.fig, - global_step=trainer.current_epoch, - close=False, - ) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - if self.tensorboard: - self.add_to_tensorboard(trainer, pl_module) - if not self.block: - plt.pause(0.05) - else: - plt.show(block=True) + self.log_and_display(trainer, pl_module) + + +class VisCBC2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): + x_train, y_train = self.x_train, self.y_train + protos = pl_module.components + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") + ax.scatter( + protos[:, 0], + protos[:, 1], + c="w", + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + x = np.vstack((x_train, protos)) + mesh_input, xx, yy = self.get_mesh_input(x) + y_pred = pl_module.predict(torch.Tensor(mesh_input)) + y_pred = y_pred.reshape(xx.shape) + + ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + self.log_and_display(trainer, pl_module) class VisNG2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): + x_train, y_train = self.x_train, self.y_train protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() - # Visualize the data and the prototypes - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(self.x_train[:, 0], - self.x_train[:, 1], - c=self.y_train, - edgecolor="k") + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( protos[:, 0], protos[:, 1], @@ -417,9 +446,4 @@ class VisNG2D(Vis2DAbstract): "k-", ) - if self.tensorboard: - self.add_to_tensorboard(trainer, pl_module) - if not self.block: - plt.pause(0.01) - else: - plt.show(block=True) + self.log_and_display(trainer, pl_module) From 17315ff24253fc66906b7eadb00b20bdb2ca8e94 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:23:52 +0200 Subject: [PATCH 22/27] Add models to the prototorch.models namespace --- prototorch/models/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 2507e45..bde1ef1 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -1,5 +1,10 @@ from importlib.metadata import PackageNotFoundError, version +from .cbc import CBC +from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ +from .neural_gas import NeuralGas +from .vis import * + VERSION_FALLBACK = "uninstalled_version" try: __version__ = version(__name__.replace(".", "-")) From d7972a69e8592c44916c5a5d7366d523b7820107 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:24:47 +0200 Subject: [PATCH 23/27] Update GMLVQ model --- prototorch/models/glvq.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 868d4c9..6fe76d3 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -191,6 +191,26 @@ class GMLVQ(GLVQ): self.hparams.latent_dim, bias=False) + @property + def omega_matrix(self): + return self.omega_layer.weight.detach().cpu() + + @property + def lambda_matrix(self): + omega = self.omega_layer.weight + lam = omega @ omega.T + return lam.detach().cpu() + + def show_lambda(self): + import matplotlib.pyplot as plt + title = "Lambda matrix" + plt.figure(title) + plt.title(title) + plt.imshow(self.lambda_matrix, cmap="gray") + plt.axis("off") + plt.colorbar() + plt.show(block=True) + def forward(self, x): protos, _ = self.proto_layer() latent_x = self.omega_layer(x) From 728131e9dbfc452fdf76a864a6c326a8f240911b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:25:04 +0200 Subject: [PATCH 24/27] Update example scripts --- examples/cbc_iris.py | 90 +++++------------------------------ examples/glvq_iris.py | 25 +++++----- examples/glvq_spiral.py | 23 ++++----- examples/gmlvq_iris.py | 37 +++++--------- examples/liramlvq_tecator.py | 34 +++++++------ examples/ng_iris.py | 36 +++++--------- examples/siamese_glvq_iris.py | 21 ++++---- 7 files changed, 83 insertions(+), 183 deletions(-) diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index 5497e6c..92f0791 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -1,102 +1,36 @@ """CBC example using the Iris dataset.""" -import numpy as np +import prototorch as pt import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.models.cbc import CBC, euclidean_similarity - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.components - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( input_dim=x_train.shape[1], - nclasses=len(np.unique(y_train)), + nclasses=3, num_components=9, - component_initializer=cinit.StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), + component_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = CBC( - hparams, - data=[x_train, y_train], - similarity=euclidean_similarity, - ) + model = pt.models.CBC(hparams) # Callbacks - dvis = VisualizationCallback(x_train, - y_train, - prototype_model=False, - title="CBC Iris Example") + dvis = pt.models.VisCBC2D(data=(x_train, y_train), + title="CBC Iris Example") # Setup trainer trainer = pl.Trainer( @@ -107,4 +41,4 @@ if __name__ == "__main__": ) # Training loop - trainer.fit(model, train_loader) \ No newline at end of file + trainer.fit(model, train_loader) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 3e698f7..95982e7 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -1,40 +1,39 @@ """GLVQ example using the Iris dataset.""" +import prototorch as pt import pytorch_lightning as pl import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisGLVQ2D -from prototorch.models.glvq import GLVQ if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( nclasses=3, prototypes_per_class=2, - prototype_initializer=cinit.StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), + prototype_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = GLVQ(hparams) + model = pt.models.GLVQ(hparams) + + # Callbacks + vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) # Setup trainer trainer = pl.Trainer( max_epochs=50, - callbacks=[VisGLVQ2D(x_train, y_train)], + callbacks=[vis], ) # Training loop diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index 824bd0d..3fee454 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -1,14 +1,8 @@ """GLVQ example using the spiral dataset.""" +import prototorch as pt import pytorch_lightning as pl import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from prototorch.datasets.spiral import make_spiral -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisGLVQ2D -from prototorch.models.glvq import GLVQ class StopOnNaN(pl.Callback): @@ -23,29 +17,28 @@ class StopOnNaN(pl.Callback): if __name__ == "__main__": # Dataset - x_train, y_train = make_spiral(n_samples=600, noise=0.6) - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=256) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=256) # Hyperparameters hparams = dict( nclasses=2, prototypes_per_class=20, - prototype_initializer=cinit.SSI(torch.Tensor(x_train), - torch.Tensor(y_train), - noise=1e-7), + prototype_initializer=pt.components.SSI(train_ds, noise=1e-7), transfer_function="sigmoid_beta", transfer_beta=10.0, lr=0.01, ) # Initialize the model - model = GLVQ(hparams) + model = pt.models.GLVQ(hparams) # Callbacks - vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True) + vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True) snan = StopOnNaN(model.proto_layer.components) # Setup trainer diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index ce15e56..ba903e7 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -1,48 +1,37 @@ """GMLVQ example using all four dimensions of the Iris dataset.""" +import prototorch as pt import pytorch_lightning as pl import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import GMLVQ if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( nclasses=3, prototypes_per_class=1, - prototype_initializer=cinit.SMI(torch.Tensor(x_train), - torch.Tensor(y_train)), input_dim=x_train.shape[1], - latent_dim=2, + latent_dim=x_train.shape[1], + prototype_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = GMLVQ(hparams) - - # Model summary - print(model) - - # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) - - # Namespace hook for the visualization to work - model.backbone = model.omega_layer + model = pt.models.GMLVQ(hparams) # Setup trainer - trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) + trainer = pl.Trainer(max_epochs=100) # Training loop trainer.fit(model, train_loader) + + # Display the Lambda matrix + model.show_lambda() diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index d9fff1e..b7cc21a 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -1,47 +1,45 @@ -"""Limited Rank MLVQ example using the Tecator dataset.""" +"""Limited Rank Matrix LVQ example using the Tecator dataset.""" +import prototorch as pt import pytorch_lightning as pl -from prototorch.components import initializers as cinit -from prototorch.datasets.tecator import Tecator -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import GMLVQ +import torch if __name__ == "__main__": # Dataset - train_ds = Tecator(root="./datasets/", train=True) + train_ds = pt.datasets.Tecator(root="~/datasets/", train=True) + + # Reproducibility + pl.utilities.seed.seed_everything(seed=42) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=32) - - # Grab the full dataset to warm-start prototypes - x, y = next(iter(DataLoader(train_ds, batch_size=len(train_ds)))) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=32) # Hyperparameters hparams = dict( nclasses=2, prototypes_per_class=2, - prototype_initializer=cinit.SMI(x, y), - input_dim=x.shape[1], + input_dim=100, latent_dim=2, - lr=0.01, + prototype_initializer=pt.components.SMI(train_ds), + lr=0.001, ) # Initialize the model - model = GMLVQ(hparams) + model = pt.models.GMLVQ(hparams) # Model summary print(model) # Callbacks - vis = VisSiameseGLVQ2D(x, y) + vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) # Namespace hook for the visualization to work model.backbone = model.omega_layer # Setup trainer - trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) + trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) # Training loop trainer.fit(model, train_loader) diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 8b3f51f..38e6162 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -1,50 +1,40 @@ """Neural Gas example using the Iris dataset.""" +import prototorch as pt import pytorch_lightning as pl -from prototorch.datasets.abstract import NumpyDataset -from sklearn.datasets import load_iris -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader - -from prototorch.models.callbacks.visualization import VisNG2D -from prototorch.models.neural_gas import NeuralGas +import torch if __name__ == "__main__": - # Dataset + # Prepare and pre-process the dataset + from sklearn.datasets import load_iris + from sklearn.preprocessing import StandardScaler x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] scaler = StandardScaler() scaler.fit(x_train) x_train = scaler.transform(x_train) - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - num_prototypes=30, - lr=0.01, - ) + hparams = dict(num_prototypes=30, lr=0.03) # Initialize the model - model = NeuralGas(hparams) + model = pt.models.NeuralGas(hparams) # Model summary print(model) # Callbacks - vis = VisNG2D(x_train, y_train) + vis = pt.models.VisNG2D(data=train_ds) # Setup trainer - trainer = pl.Trainer( - max_epochs=100, - callbacks=[ - vis, - ], - ) + trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) # Training loop trainer.fit(model, train_loader) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index d117f4f..a6390b2 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -1,13 +1,8 @@ """Siamese GLVQ example using all four dimensions of the Iris dataset.""" +import prototorch as pt import pytorch_lightning as pl import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import SiameseGLVQ -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader class Backbone(torch.nn.Module): @@ -29,27 +24,29 @@ class Backbone(torch.nn.Module): if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Reproducibility pl.utilities.seed.seed_everything(seed=2) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( nclasses=3, prototypes_per_class=2, - prototype_initializer=cinit.SMI(torch.Tensor(x_train), - torch.Tensor(y_train)), + prototype_initializer=pt.components.SMI((x_train, y_train)), proto_lr=0.001, bb_lr=0.001, ) # Initialize the model - model = SiameseGLVQ( + model = pt.models.SiameseGLVQ( hparams, backbone_module=Backbone, ) @@ -58,7 +55,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train, border=0.1) + vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1) # Setup trainer trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) From 63a5a98491f52248af9368a39466eea0cd9759d6 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:41:53 +0200 Subject: [PATCH 25/27] Update readme --- README.md | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index e16667f..38e9258 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,15 @@ PyTorch-Lightning. ## Installation -To install this plugin, simple install -[ProtoTorch](https://github.com/si-cim/prototorch) first by following the -installation instructions there and then install this plugin by doing: +To install this plugin, first install +[ProtoTorch](https://github.com/si-cim/prototorch) with: + +```sh +git clone https://github.com/si-cim/prototorch.git && cd prototorch +pip install -e . +``` + +and then install the plugin itself with: ```sh git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models @@ -28,9 +34,14 @@ following: ```sh export WORKON_HOME=~/pyenvs mkdir -p $WORKON_HOME -source /usr/local/bin/virtualenvwrapper.sh # might be different -# source ~/.local/bin/virtualenvwrapper.sh +source /usr/local/bin/virtualenvwrapper.sh # location may vary mkvirtualenv pt +``` + +Once you have a virtual environment setup, you can start install the `models` +plugin with: + +```sh workon pt git clone git@github.com:si-cim/prototorch_models.git cd prototorch_models @@ -43,23 +54,31 @@ To assist in the development process, you may also find it useful to install ## Available models -- GLVQ +- Generalized Learning Vector Quantization (GLVQ) +- Generalized Matrix Learning Vector Quantization (GMLVQ) +- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ) - Siamese GLVQ -- Neural Gas -- GMLVQ -- Limited-Rank GMLVQ +- Neural Gas (NG) ## Work in Progress -- CBC -- LVQMLN +- Classification-By-Components Network (CBC) +- Learning Vector Quantization Multi-Layer Network (LVQMLN) ## Planned models - Local-Matrix GMLVQ -- GTLVQ -- RSLVQ -- PLVQ -- SILVQ -- KNN -- LVQ1 +- Generalized Tangent Learning Vector Quantization (GTLVQ) +- Robust Soft Learning Vector Quantization (RSLVQ) +- Probabilistic Learning Vector Quantization (PLVQ) +- Self-Incremental Learning Vector Quantization (SILVQ) +- K-Nearest Neighbors (KNN) +- Learning Vector Quantization 1 (LVQ1) + +## FAQ + +### How do I update the plugin? + +If you have already cloned and installed `prototorch` and the +`prototorch_models` plugin with the `-e` flag via `pip`, all you have to do is +navigate to those folders from your terminal and do `git pull` to update. From 11b3e53ecb61402876bfd99ebf09bfbdc433c59b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:45:37 +0200 Subject: [PATCH 26/27] Return prototypes as torch tensor --- prototorch/models/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index dcc89a9..562b6d3 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -20,4 +20,4 @@ class AbstractLightningModel(pl.LightningModule): class AbstractPrototypeModel(AbstractLightningModel): @property def prototypes(self): - return self.proto_layer.components.detach().numpy() + return self.proto_layer.components.detach().cpu() From dd75fbfff8c2309a3a6405a4d62cbe8f3d4d214a Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:46:09 +0200 Subject: [PATCH 27/27] Make cbc example reproducible --- examples/cbc_iris.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index 92f0791..ae5c5ae 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -11,6 +11,9 @@ if __name__ == "__main__": x_train = x_train[:, [0, 2]] train_ds = pt.datasets.NumpyDataset(x_train, y_train) + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + # Dataloaders train_loader = torch.utils.data.DataLoader(train_ds, num_workers=0, @@ -20,8 +23,8 @@ if __name__ == "__main__": hparams = dict( input_dim=x_train.shape[1], nclasses=3, - num_components=9, - component_initializer=pt.components.SMI(train_ds), + num_components=5, + component_initializer=pt.components.SSI(train_ds, noise=0.01), lr=0.01, ) @@ -34,7 +37,7 @@ if __name__ == "__main__": # Setup trainer trainer = pl.Trainer( - max_epochs=50, + max_epochs=200, callbacks=[ dvis, ],