diff --git a/README.md b/README.md index 6b91b99..38e9258 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,15 @@ PyTorch-Lightning. ## Installation -To install this plugin, simple install -[ProtoTorch](https://github.com/si-cim/prototorch) first by following the -installation instructions there and then install this plugin by doing: +To install this plugin, first install +[ProtoTorch](https://github.com/si-cim/prototorch) with: + +```sh +git clone https://github.com/si-cim/prototorch.git && cd prototorch +pip install -e . +``` + +and then install the plugin itself with: ```sh git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models @@ -28,9 +34,14 @@ following: ```sh export WORKON_HOME=~/pyenvs mkdir -p $WORKON_HOME -source /usr/local/bin/virtualenvwrapper.sh # might be different -# source ~/.local/bin/virtualenvwrapper.sh +source /usr/local/bin/virtualenvwrapper.sh # location may vary mkvirtualenv pt +``` + +Once you have a virtual environment set up, you can start installing the `models` +plugin with: + +```sh workon pt git clone git@github.com:si-cim/prototorch_models.git cd prototorch_models @@ -43,18 +54,31 @@ To assist in the development process, you may also find it useful to install ## Available models -- GLVQ +- Generalized Learning Vector Quantization (GLVQ) +- Generalized Matrix Learning Vector Quantization (GMLVQ) +- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ) - Siamese GLVQ -- Neural Gas +- Neural Gas (NG) ## Work in Progress -- CBC + +- Classification-By-Components Network (CBC) +- Learning Vector Quantization Multi-Layer Network (LVQMLN) ## Planned models -- GMLVQ + - Local-Matrix GMLVQ -- Limited-Rank GMLVQ -- GTLVQ -- RSLVQ -- PLVQ -- LVQMLN +- Generalized Tangent Learning Vector Quantization (GTLVQ) +- Robust Soft Learning Vector Quantization (RSLVQ) +- Probabilistic Learning Vector Quantization (PLVQ) +- Self-Incremental Learning Vector Quantization (SILVQ) +- K-Nearest Neighbors (KNN) +- Learning Vector 
Quantization 1 (LVQ1) + +## FAQ + +### How do I update the plugin? + +If you have already cloned and installed `prototorch` and the +`prototorch_models` plugin with the `-e` flag via `pip`, all you have to do is +navigate to those folders from your terminal and do `git pull` to update. diff --git a/examples/cbc_circle.py b/examples/cbc_circle.py deleted file mode 100644 index 41702f7..0000000 --- a/examples/cbc_circle.py +++ /dev/null @@ -1,129 +0,0 @@ -"""CBC example using the Iris dataset.""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from sklearn.datasets import make_circles -from torch.utils.data import DataLoader - -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisPointProtos -from prototorch.models.cbc import CBC, euclidean_similarity -from prototorch.models.glvq import GLVQ - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.prototypes - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - 
np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_circles(n_samples=300, - shuffle=True, - noise=0.05, - random_state=None, - factor=0.5) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=len(np.unique(y_train)), - prototypes_per_class=5, - prototype_initializer="randn", - lr=0.01, - ) - - # Initialize the model - model = CBC( - hparams, - data=[x_train, y_train], - similarity=euclidean_similarity, - ) - - model = GLVQ(hparams, data=[x_train, y_train]) - - # Fix the component locations - # model.proto_layer.requires_grad_(False) - - # import sys - # sys.exit() - - # Model summary - print(model) - - # Callbacks - dvis = VisPointProtos( - data=(x_train, y_train), - save=True, - snap=False, - voronoi=True, - resolution=50, - pause_time=0.1, - make_gif=True, - ) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=10, - callbacks=[ - dvis, - ], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index e5eab9e..ae5c5ae 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -1,112 +1,45 @@ """CBC example using the Iris dataset.""" -import numpy as np +import prototorch as pt import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.cbc import CBC - - -class VisualizationCallback(pl.Callback): - def 
__init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module): - # protos = pl_module.prototypes - protos = pl_module.components - # plabels = pl_module.prototype_labels - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - # c=plabels, - c="k", - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) + + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( input_dim=x_train.shape[1], nclasses=3, - prototypes_per_class=3, - prototype_initializer="stratified_mean", + num_components=5, + component_initializer=pt.components.SSI(train_ds, noise=0.01), lr=0.01, ) 
# Initialize the model - model = CBC(hparams, data=[x_train, y_train]) - - # Fix the component locations - # model.proto_layer.requires_grad_(False) - - # Pure-positive reasonings - ncomps = 3 - nclasses = 3 - rmat = torch.stack( - [0.9 * torch.eye(ncomps), - torch.zeros(ncomps, nclasses)], dim=0) - # model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat}, - # strict=True) - - print(model.reasoning_layer.reasoning_probabilities) - # import sys - # sys.exit() - - # Model summary - print(model) + model = pt.models.CBC(hparams) # Callbacks - vis = VisualizationCallback(x_train, y_train) + dvis = pt.models.VisCBC2D(data=(x_train, y_train), + title="CBC Iris Example") # Setup trainer trainer = pl.Trainer( - max_epochs=100, + max_epochs=200, callbacks=[ - vis, + dvis, ], ) diff --git a/examples/cbc_mnist.py b/examples/cbc_mnist.py deleted file mode 100644 index 8546025..0000000 --- a/examples/cbc_mnist.py +++ /dev/null @@ -1,128 +0,0 @@ -"""CBC example using the MNIST dataset. - -This script also shows how to use Tensorboard for visualizing the prototypes. 
-""" - -import argparse - -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader -from torchvision import transforms -from torchvision.datasets import MNIST - -from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity - - -class VisualizationCallback(pl.Callback): - def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): - super().__init__() - self.to_shape = to_shape - self.nrow = nrow - - def on_epoch_end(self, trainer, pl_module: ImageCBC): - tb = pl_module.logger.experiment - - # components - components = pl_module.components - components_img = components.reshape(self.to_shape) - grid = torchvision.utils.make_grid(components_img, nrow=self.nrow) - tb.add_image( - tag="MNIST Components", - img_tensor=grid, - global_step=trainer.current_epoch, - dataformats="CHW", - ) - # Reasonings - reasonings = pl_module.reasonings - tb.add_images( - tag="MNIST Reasoning", - img_tensor=reasonings, - global_step=trainer.current_epoch, - dataformats="NCHW", - ) - - -if __name__ == "__main__": - # Arguments - parser = argparse.ArgumentParser() - parser.add_argument("--epochs", - type=int, - default=10, - help="Epochs to train.") - parser.add_argument("--lr", - type=float, - default=0.001, - help="Learning rate.") - parser.add_argument("--batch_size", - type=int, - default=256, - help="Batch size.") - parser.add_argument("--gpus", - type=int, - default=0, - help="Number of GPUs to use.") - parser.add_argument("--ppc", - type=int, - default=1, - help="Prototypes-Per-Class.") - args = parser.parse_args() - - # Dataset - mnist_train = MNIST( - "./datasets", - train=True, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - mnist_test = MNIST( - "./datasets", - train=False, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - - # Dataloaders - train_loader = 
DataLoader(mnist_train, batch_size=32) - test_loader = DataLoader(mnist_test, batch_size=32) - - # Grab the full dataset to warm-start prototypes - x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train)))) - x = x.view(len(mnist_train), -1) - - # Hyperparameters - hparams = dict( - input_dim=28 * 28, - nclasses=10, - prototypes_per_class=args.ppc, - prototype_initializer="randn", - lr=0.01, - similarity=euclidean_similarity, - ) - - # Initialize the model - model = CBC(hparams, data=[x, y]) - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc) - - # Setup trainer - trainer = pl.Trainer( - gpus=args.gpus, # change to use GPUs for training - max_epochs=args.epochs, - callbacks=[vis], - track_grad_norm=2, - # accelerator="ddp_cpu", # DEBUG-ONLY - # num_processes=2, # DEBUG-ONLY - ) - - # Training loop - trainer.fit(model, train_loader, test_loader) diff --git a/examples/cbc_spiral.py b/examples/cbc_spiral.py deleted file mode 100644 index 8147ba6..0000000 --- a/examples/cbc_spiral.py +++ /dev/null @@ -1,135 +0,0 @@ -"""CBC example using the Iris dataset.""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from torch.utils.data import DataLoader - -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.cbc import CBC - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.prototypes - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - 
ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -def make_spirals(n_samples=500, noise=0.3): - def get_samples(n, delta_t): - points = [] - for i in range(n): - r = i / n_samples * 5 - t = 1.75 * i / n * 2 * np.pi + delta_t - x = r * np.sin(t) + np.random.rand(1) * noise - y = r * np.cos(t) + np.random.rand(1) * noise - points.append([x, y]) - return points - - n = n_samples // 2 - positive = get_samples(n=n, delta_t=0) - negative = get_samples(n=n, delta_t=np.pi) - x = np.concatenate( - [np.array(positive).reshape(n, -1), - np.array(negative).reshape(n, -1)], - axis=0) - y = np.concatenate([np.zeros(n), np.ones(n)]) - return x, y - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_spirals(n_samples=1000, noise=0.3) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=2, - prototypes_per_class=40, - prototype_initializer="stratified_random", - lr=0.05, - ) - - # Initialize the model - model_class = CBC - model = model_class(hparams, data=[x_train, y_train]) - - # Pure-positive reasonings - new_reasoning = torch.zeros_like( - 
model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(model.proto_layer.prototype_labels): - new_reasoning[0][0][i][int(label)] = 1.0 - - model.reasoning_layer.reasoning_probabilities.data = new_reasoning - - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(x_train, - y_train, - prototype_model=hasattr(model, "prototypes")) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=500, - callbacks=[ - vis, - ], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/cbc_spiral_with_glvq_start.py b/examples/cbc_spiral_with_glvq_start.py deleted file mode 100644 index b6122e7..0000000 --- a/examples/cbc_spiral_with_glvq_start.py +++ /dev/null @@ -1,146 +0,0 @@ -"""CBC example using the spirals dataset. - -This example shows how to jump start a model by transferring weights from -another more stable model. -""" - -import numpy as np -import pytorch_lightning as pl -import torch -from matplotlib import pyplot as plt -from torch.utils.data import DataLoader - -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.cbc import CBC -from prototorch.models.glvq import GLVQ - - -class VisualizationCallback(pl.Callback): - def __init__( - self, - x_train, - y_train, - prototype_model=True, - title="Prototype Visualization", - cmap="viridis", - ): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - self.prototype_model = prototype_model - - def on_epoch_end(self, trainer, pl_module): - if self.prototype_model: - protos = pl_module.prototypes - color = pl_module.prototype_labels - else: - protos = pl_module.components - color = "k" - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=color, - 
cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - - -def make_spirals(n_samples=500, noise=0.3): - def get_samples(n, delta_t): - points = [] - for i in range(n): - r = i / n_samples * 5 - t = 1.75 * i / n * 2 * np.pi + delta_t - x = r * np.sin(t) + np.random.rand(1) * noise - y = r * np.cos(t) + np.random.rand(1) * noise - points.append([x, y]) - return points - - n = n_samples // 2 - positive = get_samples(n=n, delta_t=0) - negative = get_samples(n=n, delta_t=np.pi) - x = np.concatenate( - [np.array(positive).reshape(n, -1), - np.array(negative).reshape(n, -1)], - axis=0) - y = np.concatenate([np.zeros(n), np.ones(n)]) - return x, y - - -def train(model, x_train, y_train, train_loader, epochs=100): - # Callbacks - vis = VisualizationCallback(x_train, - y_train, - prototype_model=hasattr(model, "prototypes")) - # Setup trainer - trainer = pl.Trainer( - max_epochs=epochs, - callbacks=[ - vis, - ], - ) - # Training loop - trainer.fit(model, train_loader) - - -if __name__ == "__main__": - # Dataset - x_train, y_train = make_spirals(n_samples=1000, noise=0.3) - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - nclasses=2, - prototypes_per_class=40, - prototype_initializer="stratified_random", - lr=0.05, - ) - - # Initialize the model - glvq_model = GLVQ(hparams, data=[x_train, 
y_train]) - cbc_model = CBC(hparams, data=[x_train, y_train]) - - # Train GLVQ - train(glvq_model, x_train, y_train, train_loader, epochs=10) - - # Transfer Prototypes - cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict()) - # Pure-positive reasonings - new_reasoning = torch.zeros_like( - cbc_model.reasoning_layer.reasoning_probabilities) - for i, label in enumerate(cbc_model.proto_layer.prototype_labels): - new_reasoning[0][0][i][int(label)] = 1.0 - new_reasoning[1][0][i][1 - int(label)] = 1.0 - - cbc_model.reasoning_layer.reasoning_probabilities.data = new_reasoning - - # Train CBC - train(cbc_model, x_train, y_train, train_loader, epochs=50) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 5ada61d..95982e7 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -1,131 +1,40 @@ """GLVQ example using the Iris dataset.""" -import argparse - -import numpy as np +import prototorch as pt import pytorch_lightning as pl import torch -from matplotlib import pyplot as plt -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.glvq import GLVQ - - -class GLVQIris(GLVQ): - @staticmethod - def add_model_specific_args(parent_parser): - parser = argparse.ArgumentParser(parents=[parent_parser], - add_help=False) - parser.add_argument("--epochs", type=int, default=1) - parser.add_argument("--lr", type=float, default=1e-1) - parser.add_argument("--batch_size", type=int, default=150) - parser.add_argument("--input_dim", type=int, default=2) - parser.add_argument("--nclasses", type=int, default=3) - parser.add_argument("--prototypes_per_class", type=int, default=3) - parser.add_argument("--prototype_initializer", - type=str, - default="stratified_mean") - return parser - - -class VisualizationCallback(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - 
super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) - if __name__ == "__main__": - # For best-practices when using `argparse` with `pytorch_lightning`, see - # https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html - parser = argparse.ArgumentParser() - # Dataset + from sklearn.datasets import load_iris x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) - # Add model specific args - parser = GLVQIris.add_model_specific_args(parser) - - # Callbacks - vis = VisualizationCallback(x_train, y_train) - - # Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc. 
- parser = pl.Trainer.add_argparse_args(parser) - - # Setup trainer - trainer = pl.Trainer.from_argparse_args( - parser, - max_epochs=10, - callbacks=[ - vis, - ], # comment this line out to disable the visualization + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=2, + prototype_initializer=pt.components.SMI(train_ds), + lr=0.01, ) - # trainer.tune(model) # Initialize the model - args = parser.parse_args() - model = GLVQIris(args, data=[x_train, y_train]) + model = pt.models.GLVQ(hparams) - # Model summary - print(model) + # Callbacks + vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) + + # Setup trainer + trainer = pl.Trainer( + max_epochs=50, + callbacks=[vis], + ) # Training loop trainer.fit(model, train_loader) - - # Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate) - ckpt = "glvq_iris.ckpt" - trainer.save_checkpoint(ckpt) - - # Load the checkpoint - new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt) - - print(new_model) - - # Continue training - trainer.fit(new_model, train_loader) # TODO See why this fails! 
diff --git a/examples/glvq_iris_v1.py b/examples/glvq_iris_v1.py deleted file mode 100644 index aff2d26..0000000 --- a/examples/glvq_iris_v1.py +++ /dev/null @@ -1,40 +0,0 @@ -"""GLVQ example using the Iris dataset.""" - -import pytorch_lightning as pl -import torch -from prototorch.components import initializers as cinit -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisGLVQ2D -from prototorch.models.glvq import GLVQ -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader - -if __name__ == "__main__": - # Dataset - x_train, y_train = load_iris(return_X_y=True) - x_train = x_train[:, [0, 2]] - train_ds = NumpyDataset(x_train, y_train) - - # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) - - # Hyperparameters - hparams = dict( - nclasses=3, - prototypes_per_class=2, - prototype_initializer=cinit.StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), - lr=0.01, - ) - - # Initialize the model - model = GLVQ(hparams, data=[x_train, y_train]) - - # Setup trainer - trainer = pl.Trainer( - max_epochs=50, - callbacks=[VisGLVQ2D(x_train, y_train)], - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/glvq_mnist.py b/examples/glvq_mnist.py deleted file mode 100644 index 6ab5091..0000000 --- a/examples/glvq_mnist.py +++ /dev/null @@ -1,118 +0,0 @@ -"""GLVQ example using the MNIST dataset. - -This script also shows how to use Tensorboard for visualizing the prototypes. 
-""" - -import argparse - -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader -from torchvision import transforms -from torchvision.datasets import MNIST - -from prototorch.models.glvq import ImageGLVQ - - -class VisualizationCallback(pl.Callback): - def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): - super().__init__() - self.to_shape = to_shape - self.nrow = nrow - - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.proto_layer.prototypes.detach().cpu() - protos_img = protos.reshape(self.to_shape) - grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow) - # grid = grid.permute((1, 2, 0)) - tb = pl_module.logger.experiment - tb.add_image( - tag="MNIST Prototypes", - img_tensor=grid, - global_step=trainer.current_epoch, - dataformats="CHW", - ) - - -if __name__ == "__main__": - # Arguments - parser = argparse.ArgumentParser() - parser.add_argument("--epochs", - type=int, - default=10, - help="Epochs to train.") - parser.add_argument("--lr", - type=float, - default=0.001, - help="Learning rate.") - parser.add_argument("--batch_size", - type=int, - default=256, - help="Batch size.") - parser.add_argument("--gpus", - type=int, - default=0, - help="Number of GPUs to use.") - parser.add_argument("--ppc", - type=int, - default=1, - help="Prototypes-Per-Class.") - args = parser.parse_args() - - # Dataset - mnist_train = MNIST( - "./datasets", - train=True, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - mnist_test = MNIST( - "./datasets", - train=False, - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ]), - ) - - # Dataloaders - train_loader = DataLoader(mnist_train, batch_size=1024) - test_loader = DataLoader(mnist_test, batch_size=1024) - - # Grab the full dataset to warm-start prototypes - x, y = next(iter(DataLoader(mnist_train, 
batch_size=len(mnist_train)))) - x = x.view(len(mnist_train), -1) - - # Hyperparameters - hparams = dict( - input_dim=28 * 28, - nclasses=10, - prototypes_per_class=1, - prototype_initializer="stratified_mean", - lr=args.lr, - ) - - # Initialize the model - model = ImageGLVQ(hparams, data=[x, y]) - - # Model summary - print(model) - - # Callbacks - vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc) - - # Setup trainer - trainer = pl.Trainer( - gpus=args.gpus, # change to use GPUs for training - max_epochs=args.epochs, - callbacks=[vis], - # accelerator="ddp_cpu", # DEBUG-ONLY - # num_processes=2, # DEBUG-ONLY - ) - - # Training loop - trainer.fit(model, train_loader, test_loader) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py new file mode 100644 index 0000000..3fee454 --- /dev/null +++ b/examples/glvq_spiral.py @@ -0,0 +1,51 @@ +"""GLVQ example using the spiral dataset.""" + +import prototorch as pt +import pytorch_lightning as pl +import torch + + +class StopOnNaN(pl.Callback): + def __init__(self, param): + super().__init__() + self.param = param + + def on_epoch_end(self, trainer, pl_module, logs={}): + if torch.isnan(self.param).any(): + raise ValueError("NaN encountered. 
Stopping.") + + +if __name__ == "__main__": + # Dataset + train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=256) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=20, + prototype_initializer=pt.components.SSI(train_ds, noise=1e-7), + transfer_function="sigmoid_beta", + transfer_beta=10.0, + lr=0.01, + ) + + # Initialize the model + model = pt.models.GLVQ(hparams) + + # Callbacks + vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True) + snan = StopOnNaN(model.proto_layer.components) + + # Setup trainer + trainer = pl.Trainer( + max_epochs=200, + callbacks=[vis, snan], + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py new file mode 100644 index 0000000..ba903e7 --- /dev/null +++ b/examples/gmlvq_iris.py @@ -0,0 +1,37 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import prototorch as pt +import pytorch_lightning as pl +import torch + +if __name__ == "__main__": + # Dataset + from sklearn.datasets import load_iris + x_train, y_train = load_iris(return_X_y=True) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=1, + input_dim=x_train.shape[1], + latent_dim=x_train.shape[1], + prototype_initializer=pt.components.SMI(train_ds), + lr=0.01, + ) + + # Initialize the model + model = pt.models.GMLVQ(hparams) + + # Setup trainer + trainer = pl.Trainer(max_epochs=100) + + # Training loop + trainer.fit(model, train_loader) + + # Display the Lambda matrix + model.show_lambda() diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py new file mode 100644 index 0000000..b7cc21a --- /dev/null +++ b/examples/liramlvq_tecator.py @@ -0,0 
+1,45 @@ +"""Limited Rank Matrix LVQ example using the Tecator dataset.""" + +import prototorch as pt +import pytorch_lightning as pl +import torch + +if __name__ == "__main__": + # Dataset + train_ds = pt.datasets.Tecator(root="~/datasets/", train=True) + + # Reproducibility + pl.utilities.seed.seed_everything(seed=42) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=32) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=2, + input_dim=100, + latent_dim=2, + prototype_initializer=pt.components.SMI(train_ds), + lr=0.001, + ) + + # Initialize the model + model = pt.models.GMLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) + + # Namespace hook for the visualization to work + model.backbone = model.omega_layer + + # Setup trainer + trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 16e954a..38e6162 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -1,51 +1,40 @@ """Neural Gas example using the Iris dataset.""" -import numpy as np +import prototorch as pt import pytorch_lightning as pl -from matplotlib import pyplot as plt -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisNG2D -from prototorch.models.neural_gas import NeuralGas -from sklearn.datasets import load_iris -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader +import torch if __name__ == "__main__": - # Dataset + # Prepare and pre-process the dataset + from sklearn.datasets import load_iris + from sklearn.preprocessing import StandardScaler x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] scaler = StandardScaler() scaler.fit(x_train) x_train = scaler.transform(x_train) - train_ds = NumpyDataset(x_train, 
y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters - hparams = dict( - input_dim=x_train.shape[1], - num_prototypes=30, - lr=0.01, - ) + hparams = dict(num_prototypes=30, lr=0.03) # Initialize the model - model = NeuralGas(hparams) + model = pt.models.NeuralGas(hparams) # Model summary print(model) # Callbacks - vis = VisNG2D(x_train, y_train) + vis = pt.models.VisNG2D(data=train_ds) # Setup trainer - trainer = pl.Trainer( - max_epochs=100, - callbacks=[ - vis, - ], - ) + trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) # Training loop trainer.fit(model, train_loader) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 897e3f0..a6390b2 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -1,17 +1,12 @@ """Siamese GLVQ example using all four dimensions of the Iris dataset.""" +import prototorch as pt import pytorch_lightning as pl import torch -from prototorch.components import (StratifiedMeanInitializer, - StratifiedSelectionInitializer) -from prototorch.datasets.abstract import NumpyDataset -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import SiameseGLVQ -from sklearn.datasets import load_iris -from torch.utils.data import DataLoader class Backbone(torch.nn.Module): + """Two fully connected layers with ReLU activation.""" def __init__(self, input_size=4, hidden_size=10, latent_size=2): super().__init__() self.input_size = input_size @@ -22,28 +17,36 @@ class Backbone(torch.nn.Module): self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.dense2(self.relu(self.dense1(x)))) + x = self.relu(self.dense1(x)) + out = self.relu(self.dense2(x)) + return out if __name__ == "__main__": # Dataset + from sklearn.datasets import load_iris 
x_train, y_train = load_iris(return_X_y=True) - train_ds = NumpyDataset(x_train, y_train) + train_ds = pt.datasets.NumpyDataset(x_train, y_train) + + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) # Dataloaders - train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=150) # Hyperparameters hparams = dict( nclasses=3, - prototypes_per_class=1, - prototype_initializer=StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), - lr=0.01, + prototypes_per_class=2, + prototype_initializer=pt.components.SMI((x_train, y_train)), + proto_lr=0.001, + bb_lr=0.001, ) # Initialize the model - model = SiameseGLVQ( + model = pt.models.SiameseGLVQ( hparams, backbone_module=Backbone, ) @@ -52,7 +55,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) + vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1) # Setup trainer trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 2507e45..bde1ef1 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -1,5 +1,10 @@ from importlib.metadata import PackageNotFoundError, version +from .cbc import CBC +from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ +from .neural_gas import NeuralGas +from .vis import * + VERSION_FALLBACK = "uninstalled_version" try: __version__ = version(__name__.replace(".", "-")) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py new file mode 100644 index 0000000..562b6d3 --- /dev/null +++ b/prototorch/models/abstract.py @@ -0,0 +1,23 @@ +import pytorch_lightning as pl +import torch +from torch.optim.lr_scheduler import ExponentialLR + + +class AbstractLightningModel(pl.LightningModule): + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), 
lr=self.hparams.lr) + scheduler = ExponentialLR(optimizer, + gamma=0.99, + last_epoch=-1, + verbose=False) + sch = { + "scheduler": scheduler, + "interval": "step", + } # called after each training step + return [optimizer], [sch] + + +class AbstractPrototypeModel(AbstractLightningModel): + @property + def prototypes(self): + return self.proto_layer.components.detach().cpu() diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 5c20260..eff85df 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -1,10 +1,9 @@ import pytorch_lightning as pl import torch import torchmetrics - +from prototorch.components.components import Components from prototorch.functions.distances import euclidean_distance from prototorch.functions.similarities import cosine_similarity -from prototorch.modules.prototypes import Prototypes1D def rescaled_cosine_similarity(x, y): @@ -93,12 +92,8 @@ class CBC(pl.LightningModule): super().__init__() self.save_hyperparameters(hparams) self.margin = margin - self.proto_layer = Prototypes1D( - input_dim=self.hparams.input_dim, - nclasses=self.hparams.nclasses, - prototypes_per_class=self.hparams.prototypes_per_class, - prototype_initializer=self.hparams.prototype_initializer, - **kwargs) + self.component_layer = Components(self.hparams.num_components, + self.hparams.component_initializer) # self.similarity = CosineSimilarity() self.similarity = similarity self.backbone = backbone_class() @@ -110,7 +105,7 @@ class CBC(pl.LightningModule): @property def components(self): - return self.proto_layer.prototypes.detach().cpu() + return self.component_layer.components.detach().cpu() @property def reasonings(self): @@ -126,7 +121,7 @@ class CBC(pl.LightningModule): def forward(self, x): self.sync_backbones() - protos, _ = self.proto_layer() + protos = self.component_layer() latent_x = self.backbone(x) latent_protos = self.backbone_dependent(protos) @@ -167,4 +162,4 @@ class ImageCBC(CBC): """ def on_train_batch_end(self, 
outputs, batch, batch_idx, dataloader_idx): # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) - self.proto_layer.prototypes.data.clamp_(0.0, 1.0) + self.component_layer.components.data.clamp_(0.0, 1.0) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 61749d7..6fe76d3 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -1,11 +1,11 @@ -import pytorch_lightning as pl import torch import torchmetrics from prototorch.components import LabeledComponents +from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac -from prototorch.functions.distances import euclidean_distance +from prototorch.functions.distances import (euclidean_distance, omega_distance, + squared_euclidean_distance) from prototorch.functions.losses import glvq_loss -from prototorch.modules.prototypes import Prototypes1D from .abstract import AbstractPrototypeModel @@ -19,50 +19,53 @@ class GLVQ(AbstractPrototypeModel): # Default Values self.hparams.setdefault("distance", euclidean_distance) + self.hparams.setdefault("optimizer", torch.optim.Adam) + self.hparams.setdefault("transfer_function", "identity") + self.hparams.setdefault("transfer_beta", 10.0) self.proto_layer = LabeledComponents( labels=(self.hparams.nclasses, self.hparams.prototypes_per_class), initializer=self.hparams.prototype_initializer) + self.transfer_function = get_activation(self.hparams.transfer_function) self.train_acc = torchmetrics.Accuracy() @property def prototype_labels(self): - return self.proto_layer.component_labels.detach().numpy() + return self.proto_layer.component_labels.detach().cpu() def forward(self, x): protos, _ = self.proto_layer() dis = self.hparams.distance(x, protos) return dis - def training_step(self, train_batch, batch_idx): + def training_step(self, train_batch, batch_idx, optimizer_idx=None): x, y = train_batch - x = x.view(x.size(0), -1) + x = x.view(x.size(0), -1) # flatten dis = self(x)
plabels = self.proto_layer.component_labels mu = glvq_loss(dis, y, prototype_labels=plabels) - loss = mu.sum(dim=0) - self.log("train_loss", loss) + batch_loss = self.transfer_function(mu, + beta=self.hparams.transfer_beta) + loss = batch_loss.sum(dim=0) + + # Compute training accuracy with torch.no_grad(): preds = wtac(dis, plabels) - # self.train_acc.update(preds.int(), y.int()) - self.train_acc( - preds.int(), - y.int()) # FloatTensors are assumed to be class probabilities - self.log( - "acc", - self.train_acc, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss - # def training_epoch_end(self, outs): - # # Calling `self.train_acc.compute()` is - # # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)` - # self.log("train_acc_epoch", self.train_acc.compute()) + self.train_acc(preds.int(), y.int()) + # `.int()` because FloatTensors are assumed to be class probabilities + + # Logging + self.log("train_loss", loss) + self.log("acc", + self.train_acc, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True) + + return loss def predict(self, x): # model.eval() # ?! @@ -76,8 +79,9 @@ class GLVQ(AbstractPrototypeModel): class ImageGLVQ(GLVQ): """GLVQ for training on image data. - GLVQ model that constrains the prototypes to the range [0, 1] by - clamping after updates. + GLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.proto_layer.components.data.clamp_(0.0, 1.0) @@ -89,6 +93,155 @@ class SiameseGLVQ(GLVQ): GLVQ model that applies an arbitrary transformation on the inputs and the prototypes before computing the distances between them. The weights in the transformation pipeline are only learned from the inputs. 
+ + """ + def __init__(self, + hparams, + backbone_module=torch.nn.Identity, + backbone_params={}, + sync=True, + **kwargs): + super().__init__(hparams, **kwargs) + self.backbone = backbone_module(**backbone_params) + self.backbone_dependent = backbone_module( + **backbone_params).requires_grad_(False) + self.sync = sync + + def sync_backbones(self): + master_state = self.backbone.state_dict() + self.backbone_dependent.load_state_dict(master_state, strict=True) + + def configure_optimizers(self): + optim = self.hparams.optimizer + proto_opt = optim(self.proto_layer.parameters(), + lr=self.hparams.proto_lr) + if list(self.backbone.parameters()): + # only add an optimizer if the backbone has trainable parameters + # otherwise, the next line fails + bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) + return proto_opt, bb_opt + else: + return proto_opt + + def forward(self, x): + if self.sync: + self.sync_backbones() + protos, _ = self.proto_layer() + latent_x = self.backbone(x) + latent_protos = self.backbone_dependent(protos) + dis = euclidean_distance(latent_x, latent_protos) + return dis + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?!
+ with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = self.backbone_dependent(protos) + d = euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + + +class GRLVQ(GLVQ): + """Generalized Relevance Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.relevances = torch.nn.parameter.Parameter( + torch.ones(self.hparams.input_dim)) + + def forward(self, x): + protos, _ = self.proto_layer() + dis = omega_distance(x, protos, torch.diag(self.relevances)) + return dis + + def backbone(self, x): + return x @ torch.diag(self.relevances) + + @property + def relevance_profile(self): + return self.relevances.detach().cpu() + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?! + with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = protos @ torch.diag(self.relevances) + d = squared_euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + + +class GMLVQ(GLVQ): + """Generalized Matrix Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.omega_layer = torch.nn.Linear(self.hparams.input_dim, + self.hparams.latent_dim, + bias=False) + + @property + def omega_matrix(self): + return self.omega_layer.weight.detach().cpu() + + @property + def lambda_matrix(self): + omega = self.omega_layer.weight # (latent_dim, input_dim) + lam = omega.T @ omega + return lam.detach().cpu() + + def show_lambda(self): + import matplotlib.pyplot as plt + title = "Lambda matrix" + plt.figure(title) + plt.title(title) + plt.imshow(self.lambda_matrix, cmap="gray") + plt.axis("off") + plt.colorbar() + plt.show(block=True) + + def forward(self, x): + protos, _ = self.proto_layer() + latent_x = self.omega_layer(x) + latent_protos =
self.omega_layer(protos) + dis = squared_euclidean_distance(latent_x, latent_protos) + return dis + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?! + with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = self.omega_layer(protos) + d = squared_euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + + +class LVQMLN(GLVQ): + """Learning Vector Quantization Multi-Layer Network. + + GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT + on the prototypes before computing the distances between them. This of + course, means that the prototypes no longer live in the input space, but + rather in the embedding space. + """ def __init__(self, hparams, @@ -97,28 +250,17 @@ class SiameseGLVQ(GLVQ): **kwargs): super().__init__(hparams, **kwargs) self.backbone = backbone_module(**backbone_params) - self.backbone_dependent = backbone_module( - **backbone_params).requires_grad_(False) - - def sync_backbones(self): - master_state = self.backbone.state_dict() - self.backbone_dependent.load_state_dict(master_state, strict=True) def forward(self, x): - self.sync_backbones() - protos, _ = self.proto_layer() - + latent_protos, _ = self.proto_layer() latent_x = self.backbone(x) - latent_protos = self.backbone_dependent(protos) - dis = euclidean_distance(latent_x, latent_protos) return dis def predict_latent(self, x): - # model.eval() # ?!
+ """Predict `x` assuming it is already embedded in the latent space.""" with torch.no_grad(): - protos, plabels = self.proto_layer() - latent_protos = self.backbone_dependent(protos) + latent_protos, plabels = self.proto_layer() d = euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) return y_pred.numpy() diff --git a/prototorch/models/neural_gas.py b/prototorch/models/neural_gas.py index d98bfa0..bebd289 100644 --- a/prototorch/models/neural_gas.py +++ b/prototorch/models/neural_gas.py @@ -1,9 +1,7 @@ -import pytorch_lightning as pl import torch from prototorch.components import Components from prototorch.components import initializers as cinit from prototorch.functions.distances import euclidean_distance -from prototorch.modules import Prototypes1D from prototorch.modules.losses import NeuralGasEnergy from .abstract import AbstractPrototypeModel diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/vis.py similarity index 74% rename from prototorch/models/callbacks/visualization.py rename to prototorch/models/vis.py index 11bc729..6099bc3 100644 --- a/prototorch/models/callbacks/visualization.py +++ b/prototorch/models/vis.py @@ -9,6 +9,7 @@ from prototorch.utils.celluloid import Camera from prototorch.utils.colors import color_scheme from prototorch.utils.utils import (gif_from_dir, make_directory, prettify_string) +from torch.utils.data import DataLoader, Dataset class VisWeights(pl.Callback): @@ -261,29 +262,82 @@ class VisPointProtos(VisWeights): self._show_and_save(epoch) -class VisGLVQ2D(pl.Callback): +class Vis2DAbstract(pl.Callback): def __init__(self, - x_train, - y_train, + data, title="Prototype Visualization", - cmap="viridis"): + cmap="viridis", + border=1, + resolution=50, + tensorboard=False, + show_last_only=False, + pause_time=0.1, + block=False): super().__init__() - self.x_train = x_train - self.y_train = y_train + + if isinstance(data, Dataset): + x, y = next(iter(DataLoader(data, batch_size=len(data)))) + x 
= x.view(len(data), -1) # flatten + else: + x, y = data + self.x_train = x + self.y_train = y + self.title = title self.fig = plt.figure(self.title) self.cmap = cmap + self.border = border + self.resolution = resolution + self.tensorboard = tensorboard + self.show_last_only = show_last_only + self.pause_time = pause_time + self.block = block - def on_epoch_end(self, trainer, pl_module): - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - x_train, y_train = self.x_train, self.y_train + def setup_ax(self, xlabel=None, ylabel=None): ax = self.fig.gca() ax.cla() ax.set_title(self.title) ax.axis("off") - ax.set_xlabel("Data dimension 1") - ax.set_ylabel("Data dimension 2") + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + return ax + + def get_mesh_input(self, x): + x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border + y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border + xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution), + np.arange(y_min, y_max, 1 / self.resolution)) + mesh_input = np.c_[xx.ravel(), yy.ravel()] + return mesh_input, xx, yy + + def add_to_tensorboard(self, trainer, pl_module): + tb = pl_module.logger.experiment + tb.add_figure(tag=f"{self.title}", + figure=self.fig, + global_step=trainer.current_epoch, + close=False) + + def log_and_display(self, trainer, pl_module): + if self.tensorboard: + self.add_to_tensorboard(trainer, pl_module) + if not self.block: + plt.pause(self.pause_time) + else: + plt.show(block=True) + + +class VisGLVQ2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): + if self.show_last_only: + if trainer.current_epoch != trainer.max_epochs - 1: + return + protos = pl_module.prototypes + plabels = pl_module.prototype_labels + x_train, y_train = self.x_train, self.y_train + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train,
edgecolor="k") ax.scatter( protos[:, 0], @@ -295,43 +349,25 @@ class VisGLVQ2D(pl.Callback): s=50, ) x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] + mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - plt.pause(0.1) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + self.log_and_display(trainer, pl_module) -class VisSiameseGLVQ2D(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Prototype Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - +class VisSiameseGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train x_train = pl_module.backbone(torch.Tensor(x_train)).detach() protos = pl_module.backbone(torch.Tensor(protos)).detach() - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.axis("off") + ax = self.setup_ax() ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( protos[:, 0], @@ -343,54 +379,54 @@ class VisSiameseGLVQ2D(pl.Callback): s=50, ) x = np.vstack((x_train, protos)) - x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 - y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), - np.arange(y_min, y_max, 1 / 50)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] + mesh_input, xx, yy = 
self.get_mesh_input(x) y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - ax.set_xlim(left=x_min + 0, right=x_max - 0) - ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - tb = pl_module.logger.experiment - tb.add_figure( - tag=f"{self.title}", - figure=self.fig, - global_step=trainer.current_epoch, - close=False, - ) - plt.pause(0.1) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + self.log_and_display(trainer, pl_module) -class VisNG2D(pl.Callback): - def __init__(self, - x_train, - y_train, - title="Neural Gas Visualization", - cmap="viridis"): - super().__init__() - self.x_train = x_train - self.y_train = y_train - self.title = title - self.fig = plt.figure(self.title) - self.cmap = cmap - +class VisCBC2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): + x_train, y_train = self.x_train, self.y_train + protos = pl_module.components + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") + ax.scatter( + protos[:, 0], + protos[:, 1], + c="w", + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + x = np.vstack((x_train, protos)) + mesh_input, xx, yy = self.get_mesh_input(x) + y_pred = pl_module.predict(torch.Tensor(mesh_input)) + y_pred = y_pred.reshape(xx.shape) + + ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + # ax.set_xlim(left=x_min + 0, right=x_max - 0) + # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + self.log_and_display(trainer, pl_module) + + +class VisNG2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): + x_train, y_train = self.x_train, self.y_train protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() - # Visualize the data and the prototypes - ax = self.fig.gca() - ax.cla() - ax.set_title(self.title) - ax.set_xlabel("Data dimension 1") - 
ax.set_ylabel("Data dimension 2") - ax.scatter(self.x_train[:, 0], - self.x_train[:, 1], - c=self.y_train, - edgecolor="k") + ax = self.setup_ax(xlabel="Data dimension 1", + ylabel="Data dimension 2") + ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( protos[:, 0], protos[:, 1], @@ -410,4 +446,4 @@ class VisNG2D(pl.Callback): "k-", ) - plt.pause(0.01) + self.log_and_display(trainer, pl_module)