From a16bebd0c4d9d5ce2de6c6727bb7ad46cf7a5ebb Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 29 Apr 2021 17:05:41 +0200 Subject: [PATCH] Use Components instead of Prototypes and refactor old examples --- README.md | 2 +- examples/glvq_iris_v1.py | 76 ++------- examples/ng_iris.py | 67 +------- examples/siamese_glvq_iris.py | 78 ++------- prototorch/models/callbacks/visualization.py | 159 ++++++++++++++++++- prototorch/models/glvq.py | 38 ++--- prototorch/models/neural_gas.py | 31 ++-- 7 files changed, 216 insertions(+), 235 deletions(-) diff --git a/README.md b/README.md index 4078447..6b91b99 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ workon pt git clone git@github.com:si-cim/prototorch_models.git cd prototorch_models git checkout dev -pip install -e .[all] # \[all\] if you are using zsh +pip install -e .[all] # \[all\] if you are using zsh or MacOS ``` To assist in the development process, you may also find it useful to install diff --git a/examples/glvq_iris_v1.py b/examples/glvq_iris_v1.py index 93db0f1..aff2d26 100644 --- a/examples/glvq_iris_v1.py +++ b/examples/glvq_iris_v1.py @@ -1,63 +1,14 @@ """GLVQ example using the Iris dataset.""" -import numpy as np import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisGLVQ2D +from prototorch.models.glvq import GLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.glvq import GLVQ - - -class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - x_train, y_train = self.x_train, self.y_train - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - if __name__ == "__main__": # Dataset x_train, y_train = load_iris(return_X_y=True) @@ -69,24 +20,21 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( - input_dim=x_train.shape[1], nclasses=3, - prototypes_per_class=3, - prototype_initializer="stratified_mean", - lr=0.1, + prototypes_per_class=2, + prototype_initializer=cinit.StratifiedMeanInitializer( + torch.Tensor(x_train), torch.Tensor(y_train)), + lr=0.01, ) # Initialize the model model = GLVQ(hparams, data=[x_train, y_train]) - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(x_train, y_train) - # Setup trainer - trainer = pl.Trainer(max_epochs=50, callbacks=[vis]) + trainer = pl.Trainer( + max_epochs=50, + callbacks=[VisGLVQ2D(x_train, y_train)], + ) # Training loop trainer.fit(model, train_loader) diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 022e9a2..16e954a 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -3,63 +3,13 @@ import numpy as np import pytorch_lightning as pl from matplotlib import pyplot as plt +from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisNG2D +from prototorch.models.neural_gas import NeuralGas from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.neural_gas import NeuralGas - - -class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Neural Gas Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module: NeuralGas): - protos = pl_module.proto_layer.prototypes.detach().cpu().numpy() - cmat = pl_module.topology_layer.cmat.cpu().numpy() - - # Visualize the data and the prototypes - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(self.x_train[:, 0], - self.x_train[:, 1], - c=self.y_train, - edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c="k", - edgecolor="k", - marker="D", - s=50, - ) - - # Draw connections - for i in range(len(protos)): - for j in range(len(protos)): - if cmat[i][j]: - ax.plot( - [protos[i, 0], protos[j, 0]], - [protos[i, 1], protos[j, 1]], - "k-", - ) - - plt.pause(0.01) - - if __name__ == "__main__": # Dataset x_train, y_train = load_iris(return_X_y=True) @@ -68,7 +18,6 @@ if __name__ == "__main__": scaler.fit(x_train) x_train = scaler.transform(x_train) - y_single_class = np.zeros_like(y_train) train_ds = NumpyDataset(x_train, y_train) # Dataloaders @@ -77,20 +26,18 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( input_dim=x_train.shape[1], - nclasses=1, - prototypes_per_class=30, - prototype_initializer="rand", - lr=0.1, + num_prototypes=30, + lr=0.01, ) # Initialize the model - model = NeuralGas(hparams, data=[x_train, y_single_class]) + model = NeuralGas(hparams) # Model summary print(model) # Callbacks - vis = VisualizationCallback(x_train, y_train) + vis = VisNG2D(x_train, y_train) # Setup trainer trainer = pl.Trainer( diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index edf18d5..897e3f0 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -1,70 +1,15 @@ """Siamese GLVQ example using all four dimensions of the Iris dataset.""" -import numpy as np import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt +from prototorch.components import (StratifiedMeanInitializer, + StratifiedSelectionInitializer) +from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import SiameseGLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.glvq import SiameseGLVQ - - -class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - x_train, y_train = self.x_train, self.y_train - x_train = pl_module.backbone(torch.Tensor(x_train)).detach() - protos = pl_module.backbone(torch.Tensor(protos)).detach() - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.axis("off") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 0.2, x[:, 0].max() + 0.2 - y_min, y_max = x[:, 1].min() - 0.2, x[:, 1].max() + 0.2 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - tb = pl_module.logger.experiment - tb.add_figure( - tag=f"{self.title}", - figure=self.fig, - global_step=trainer.current_epoch, - close=False, - ) - plt.pause(0.1) - class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): @@ -90,23 +35,24 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( - input_dim=x_train.shape[1], nclasses=3, prototypes_per_class=1, - prototype_initializer="stratified_mean", + prototype_initializer=StratifiedMeanInitializer( + torch.Tensor(x_train), torch.Tensor(y_train)), lr=0.01, ) # Initialize the model - model = SiameseGLVQ(hparams, - backbone_module=Backbone, - data=[x_train, y_train]) + model = SiameseGLVQ( + hparams, + backbone_module=Backbone, + ) # Model summary print(model) # Callbacks - vis = VisualizationCallback(x_train, y_train) + vis = VisSiameseGLVQ2D(x_train, y_train) # Setup trainer trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py index 9d16c64..11bc729 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/callbacks/visualization.py @@ -1,16 +1,17 @@ import os import numpy as np +import pytorch_lightning as pl import torch from matplotlib import pyplot as plt from matplotlib.offsetbox import AnchoredText - from prototorch.utils.celluloid import Camera from prototorch.utils.colors import color_scheme -from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string +from prototorch.utils.utils import (gif_from_dir, make_directory, + prettify_string) -class VisWeights(Callback): +class VisWeights(pl.Callback): """Abstract weight visualization callback.""" def __init__( self, @@ -258,3 +259,155 @@ class VisPointProtos(VisWeights): epoch = trainer.current_epoch self._display_logs(self.ax, epoch, logs) self._show_and_save(epoch) + + +class VisGLVQ2D(pl.Callback): + def __init__(self, + x_train, + y_train, + title="Prototype Visualization", + cmap="viridis"): + super().__init__() + self.x_train = x_train + self.y_train = y_train + self.title = title + self.fig = plt.figure(self.title) + self.cmap = cmap + + def on_epoch_end(self, trainer, pl_module): + protos = pl_module.prototypes + plabels = pl_module.prototype_labels + x_train, y_train = self.x_train, self.y_train + ax = self.fig.gca() + ax.cla() + ax.set_title(self.title) + ax.axis("off") + ax.set_xlabel("Data dimension 1") + ax.set_ylabel("Data dimension 2") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") + ax.scatter( + protos[:, 0], + protos[:, 1], + c=plabels, + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + x = np.vstack((x_train, protos)) + x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 + y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 + xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), + np.arange(y_min, y_max, 1 / 50)) + mesh_input = np.c_[xx.ravel(), yy.ravel()] + y_pred = pl_module.predict(torch.Tensor(mesh_input)) + y_pred = y_pred.reshape(xx.shape) + + ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + ax.set_xlim(left=x_min + 0, right=x_max - 0) + ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + plt.pause(0.1) + + +class VisSiameseGLVQ2D(pl.Callback): + def __init__(self, + x_train, + y_train, + title="Prototype Visualization", + cmap="viridis"): + super().__init__() + self.x_train = x_train + self.y_train = y_train + self.title = title + self.fig = plt.figure(self.title) + self.cmap = cmap + + def on_epoch_end(self, trainer, pl_module): + protos = pl_module.prototypes + plabels = pl_module.prototype_labels + x_train, y_train = self.x_train, self.y_train + x_train = pl_module.backbone(torch.Tensor(x_train)).detach() + protos = pl_module.backbone(torch.Tensor(protos)).detach() + ax = self.fig.gca() + ax.cla() + ax.set_title(self.title) + ax.axis("off") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") + ax.scatter( + protos[:, 0], + protos[:, 1], + c=plabels, + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + x = np.vstack((x_train, protos)) + x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 + y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 + xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), + np.arange(y_min, y_max, 1 / 50)) + mesh_input = np.c_[xx.ravel(), yy.ravel()] + y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) + y_pred = y_pred.reshape(xx.shape) + + ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + ax.set_xlim(left=x_min + 0, right=x_max - 0) + ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + tb = pl_module.logger.experiment + tb.add_figure( + tag=f"{self.title}", + figure=self.fig, + global_step=trainer.current_epoch, + close=False, + ) + plt.pause(0.1) + + +class VisNG2D(pl.Callback): + def __init__(self, + x_train, + y_train, + title="Neural Gas Visualization", + cmap="viridis"): + super().__init__() + self.x_train = x_train + self.y_train = y_train + self.title = title + self.fig = plt.figure(self.title) + self.cmap = cmap + + def on_epoch_end(self, trainer, pl_module): + protos = pl_module.prototypes + cmat = pl_module.topology_layer.cmat.cpu().numpy() + + # Visualize the data and the prototypes + ax = self.fig.gca() + ax.cla() + ax.set_title(self.title) + ax.set_xlabel("Data dimension 1") + ax.set_ylabel("Data dimension 2") + ax.scatter(self.x_train[:, 0], + self.x_train[:, 1], + c=self.y_train, + edgecolor="k") + ax.scatter( + protos[:, 0], + protos[:, 1], + c="k", + edgecolor="k", + marker="D", + s=50, + ) + + # Draw connections + for i in range(len(protos)): + for j in range(len(protos)): + if cmat[i][j]: + ax.plot( + [protos[i, 0], protos[j, 0]], + [protos[i, 1], protos[j, 1]], + "k-", + ) + + plt.pause(0.01) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 3ecd621..61749d7 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -1,14 +1,16 @@ import pytorch_lightning as pl import torch import torchmetrics - +from prototorch.components import LabeledComponents from prototorch.functions.competitions import wtac from prototorch.functions.distances import euclidean_distance from prototorch.functions.losses import glvq_loss from prototorch.modules.prototypes import Prototypes1D +from .abstract import AbstractPrototypeModel -class GLVQ(pl.LightningModule): + +class GLVQ(AbstractPrototypeModel): """Generalized Learning Vector Quantization.""" def __init__(self, hparams, **kwargs): super().__init__() @@ -18,29 +20,18 @@ class GLVQ(pl.LightningModule): # Default Values self.hparams.setdefault("distance", euclidean_distance) - self.proto_layer = Prototypes1D( - input_dim=self.hparams.input_dim, - nclasses=self.hparams.nclasses, - prototypes_per_class=self.hparams.prototypes_per_class, - prototype_initializer=self.hparams.prototype_initializer, - **kwargs) + self.proto_layer = LabeledComponents( + labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), + initializer=self.hparams.prototype_initializer) self.train_acc = torchmetrics.Accuracy() - @property - def prototypes(self): - return self.proto_layer.prototypes.detach().numpy() - @property def prototype_labels(self): - return self.proto_layer.prototype_labels.detach().numpy() - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) - return optimizer + return self.proto_layer.component_labels.detach().numpy() def forward(self, x): - protos = self.proto_layer.prototypes + protos, _ = self.proto_layer() dis = self.hparams.distance(x, protos) return dis @@ -48,7 +39,7 @@ class GLVQ(pl.LightningModule): x, y = train_batch x = x.view(x.size(0), -1) dis = self(x) - plabels = self.proto_layer.prototype_labels + plabels = self.proto_layer.component_labels mu = glvq_loss(dis, y, prototype_labels=plabels) loss = mu.sum(dim=0) self.log("train_loss", loss) @@ -77,7 +68,7 @@ class GLVQ(pl.LightningModule): # model.eval() # ?! with torch.no_grad(): d = self(x) - plabels = self.proto_layer.prototype_labels + plabels = self.proto_layer.component_labels y_pred = wtac(d, plabels) return y_pred.numpy() @@ -89,7 +80,7 @@ class ImageGLVQ(GLVQ): clamping after updates. """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.proto_layer.prototypes.data.clamp_(0.0, 1.0) + self.proto_layer.components.data.clamp_(0.0, 1.0) class SiameseGLVQ(GLVQ): @@ -115,7 +106,7 @@ class SiameseGLVQ(GLVQ): def forward(self, x): self.sync_backbones() - protos = self.proto_layer.prototypes + protos, _ = self.proto_layer() latent_x = self.backbone(x) latent_protos = self.backbone_dependent(protos) @@ -126,9 +117,8 @@ class SiameseGLVQ(GLVQ): def predict_latent(self, x): # model.eval() # ?! with torch.no_grad(): - protos = self.proto_layer.prototypes + protos, plabels = self.proto_layer() latent_protos = self.backbone_dependent(protos) d = euclidean_distance(x, latent_protos) - plabels = self.proto_layer.prototype_labels y_pred = wtac(d, plabels) return y_pred.numpy() diff --git a/prototorch/models/neural_gas.py b/prototorch/models/neural_gas.py index 98bb5e7..d98bfa0 100644 --- a/prototorch/models/neural_gas.py +++ b/prototorch/models/neural_gas.py @@ -1,10 +1,13 @@ import pytorch_lightning as pl import torch - +from prototorch.components import Components +from prototorch.components import initializers as cinit from prototorch.functions.distances import euclidean_distance from prototorch.modules import Prototypes1D from prototorch.modules.losses import NeuralGasEnergy +from .abstract import AbstractPrototypeModel + class EuclideanDistance(torch.nn.Module): def forward(self, x, y): @@ -34,41 +37,35 @@ class ConnectionTopology(torch.nn.Module): return f"agelimit: {self.agelimit}" -class NeuralGas(pl.LightningModule): +class NeuralGas(AbstractPrototypeModel): def __init__(self, hparams, **kwargs): super().__init__() self.save_hyperparameters(hparams) # Default Values + self.hparams.setdefault("input_dim", 2) self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("lm", 1) - self.hparams.setdefault("prototype_initializer", "zeros") + self.hparams.setdefault("prototype_initializer", + cinit.ZerosInitializer(self.hparams.input_dim)) - self.proto_layer = Prototypes1D( - input_dim=self.hparams.input_dim, - nclasses=self.hparams.nclasses, - prototypes_per_class=self.hparams.prototypes_per_class, - prototype_initializer=self.hparams.prototype_initializer, - **kwargs, - ) + self.proto_layer = Components( + self.hparams.num_prototypes, + initializer=self.hparams.prototype_initializer) self.distance_layer = EuclideanDistance() self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm) self.topology_layer = ConnectionTopology( agelimit=self.hparams.agelimit, - num_prototypes=len(self.proto_layer.prototypes), + num_prototypes=self.hparams.num_prototypes, ) def training_step(self, train_batch, batch_idx): - x, _ = train_batch - protos, _ = self.proto_layer() + x = train_batch[0] + protos = self.proto_layer() d = self.distance_layer(x, protos) cost, order = self.energy_layer(d) self.topology_layer(d) return cost - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) - return optimizer