diff --git a/README.md b/README.md index 8750108..154d8e7 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ To assist in the development process, you may also find it useful to install ## Available models -- K-Nearest Neighbors (KNN) +- k-Nearest Neighbors (KNN) - Learning Vector Quantization 1 (LVQ1) - Generalized Learning Vector Quantization (GLVQ) - Generalized Relevance Learning Vector Quantization (GRLVQ) @@ -68,6 +68,7 @@ To assist in the development process, you may also find it useful to install ## Planned models +- Median-LVQ - Local-Matrix GMLVQ - Generalized Tangent Learning Vector Quantization (GTLVQ) - Robust Soft Learning Vector Quantization (RSLVQ) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 94c15ce..49c0c12 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -21,12 +21,13 @@ if __name__ == "__main__": prototypes_per_class = 2 hparams = dict( distribution=(nclasses, prototypes_per_class), - prototype_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam) + model = pt.models.GLVQ(hparams, + optimizer=torch.optim.Adam, + prototype_initializer=pt.components.SMI(train_ds)) # Callbacks vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index fea4bf0..a1fd95f 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -29,14 +29,15 @@ if __name__ == "__main__": prototypes_per_class = 20 hparams = dict( distribution=(nclasses, prototypes_per_class), - prototype_initializer=pt.components.SSI(train_ds, noise=1e-1), transfer_function="sigmoid_beta", transfer_beta=10.0, lr=0.01, ) # Initialize the model - model = pt.models.GLVQ(hparams) + model = pt.models.GLVQ(hparams, + prototype_initializer=pt.components.SSI(train_ds, + noise=1e-1)) # Callbacks vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True) diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index c8c5425..65430a4 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -21,12 +21,12 @@ if __name__ == "__main__": distribution=(nclasses, prototypes_per_class), input_dim=x_train.shape[1], latent_dim=x_train.shape[1], - prototype_initializer=pt.components.SMI(train_ds), lr=0.01, ) # Initialize the model - model = pt.models.GMLVQ(hparams) + model = pt.models.GMLVQ(hparams, + prototype_initializer=pt.components.SMI(train_ds)) # Setup trainer trainer = pl.Trainer(max_epochs=100) diff --git a/examples/gmlvq_mnist.py b/examples/gmlvq_mnist.py new file mode 100644 index 0000000..3729930 --- /dev/null +++ b/examples/gmlvq_mnist.py @@ -0,0 +1,68 @@ +"""GMLVQ example using the MNIST dataset.""" + +import prototorch as pt +import pytorch_lightning as pl +import torch +from torchvision import transforms +from torchvision.datasets import MNIST + +if __name__ == "__main__": + # Dataset + train_ds = MNIST( + "~/datasets", + train=True, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + ]), + ) + test_ds = MNIST( + "~/datasets", + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + ]), + ) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=256) + test_loader = torch.utils.data.DataLoader(test_ds, + num_workers=0, + batch_size=256) + + # Hyperparameters + nclasses = 10 + prototypes_per_class = 2 + hparams = dict( + input_dim=28 * 28, + latent_dim=28 * 28, + distribution=(nclasses, prototypes_per_class), + lr=0.01, + ) + + # Initialize the model + model = pt.models.ImageGMLVQ( + hparams, + optimizer=torch.optim.Adam, + prototype_initializer=pt.components.SMI(train_ds), + ) + + # Callbacks + vis = pt.models.VisImgComp(data=train_ds, + nrow=5, + show=False, + tensorboard=True) + + # Setup trainer + trainer = pl.Trainer( + max_epochs=50, + callbacks=[vis], + # overfit_batches=1, + # fast_dev_run=3, + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index 1b07a0d..7e68b6b 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -23,12 +23,12 @@ if __name__ == "__main__": distribution=(nclasses, prototypes_per_class), input_dim=100, latent_dim=2, - prototype_initializer=pt.components.SMI(train_ds), lr=0.001, ) # Initialize the model - model = pt.models.GMLVQ(hparams) + model = pt.models.GMLVQ(hparams, + prototype_initializer=pt.components.SMI(train_ds)) # Callbacks vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index c5dc6d5..b8cca23 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -37,7 +37,6 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( distribution=[1, 2, 3], - prototype_initializer=pt.components.SMI(train_ds), proto_lr=0.01, bb_lr=0.01, ) @@ -45,6 +44,7 @@ if __name__ == "__main__": # Initialize the model model = pt.models.SiameseGLVQ( hparams, + prototype_initializer=pt.components.SMI(train_ds), backbone_module=Backbone, ) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 37a4b5b..0f0a8a8 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version from .cbc import CBC from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ, - SiameseGLVQ) + ImageGMLVQ, SiameseGLVQ) from .knn import KNN from .neural_gas import NeuralGas from .vis import * diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 082b7ed..d7f57c5 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -8,6 +8,11 @@ class AbstractPrototypeModel(pl.LightningModule): def prototypes(self): return self.proto_layer.components.detach().cpu() + @property + def components(self): + """Only an alias for the prototypes.""" + return self.prototypes + def configure_optimizers(self): optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr) scheduler = ExponentialLR(optimizer, @@ -19,3 +24,8 @@ class AbstractPrototypeModel(pl.LightningModule): "interval": "step", } # called after each training step return [optimizer], [sch] + + +class PrototypeImageModel(pl.LightningModule): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.proto_layer.components.data.clamp_(0.0, 1.0) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 1766064..2e1b366 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -5,9 +5,18 @@ from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac from prototorch.functions.distances import (euclidean_distance, omega_distance, squared_euclidean_distance) +from prototorch.functions.helper import get_flat from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss +from prototorch.modules.mappings import OmegaMapping -from .abstract import AbstractPrototypeModel +from .abstract import AbstractPrototypeModel, PrototypeImageModel + + +class GLVQ(AbstractPrototypeModel): + """Generalized Learning Vector Quantization.""" + + +from .abstract import AbstractPrototypeModel, PrototypeImageModel class GLVQ(AbstractPrototypeModel): @@ -18,6 +27,7 @@ class GLVQ(AbstractPrototypeModel): self.save_hyperparameters(hparams) self.optimizer = kwargs.get("optimizer", torch.optim.Adam) + prototype_initializer = kwargs.get("prototype_initializer", None) # Default Values self.hparams.setdefault("distance", euclidean_distance) @@ -26,7 +36,7 @@ class GLVQ(AbstractPrototypeModel): self.proto_layer = LabeledComponents( distribution=self.hparams.distribution, - initializer=self.hparams.prototype_initializer) + initializer=prototype_initializer) self.transfer_function = get_activation(self.hparams.transfer_function) self.train_acc = torchmetrics.Accuracy() @@ -44,7 +54,6 @@ class GLVQ(AbstractPrototypeModel): def training_step(self, train_batch, batch_idx, optimizer_idx=None): x, y = train_batch - x = x.view(x.size(0), -1) # flatten dis = self(x) plabels = self.proto_layer.component_labels mu = self.loss(dis, y, prototype_labels=plabels) @@ -95,15 +104,14 @@ class LVQ21(GLVQ): self.optimizer = torch.optim.SGD -class ImageGLVQ(GLVQ): +class ImageGLVQ(GLVQ, PrototypeImageModel): """GLVQ for training on image data. GLVQ model that constrains the prototypes to the range [0, 1] by clamping after updates. """ - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.proto_layer.components.data.clamp_(0.0, 1.0) + pass class SiameseGLVQ(GLVQ): @@ -235,6 +243,7 @@ class GMLVQ(GLVQ): def forward(self, x): protos, _ = self.proto_layer() + x, protos = get_flat(x, protos) latent_x = self.omega_layer(x) latent_protos = self.omega_layer(protos) dis = squared_euclidean_distance(latent_x, latent_protos) @@ -256,6 +265,16 @@ class GMLVQ(GLVQ): return y_pred.numpy() +class ImageGMLVQ(GMLVQ, PrototypeImageModel): + """GMLVQ for training on image data. + + GMLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + + """ + pass + + class LVQMLN(GLVQ): """Learning Vector Quantization Multi-Layer Network. diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index d788e0e..7ab9538 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -3,6 +3,7 @@ import os import numpy as np import pytorch_lightning as pl import torch +import torchvision from matplotlib import pyplot as plt from matplotlib.offsetbox import AnchoredText from prototorch.utils.celluloid import Camera @@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback): border=1, resolution=50, show_protos=True, + show=True, tensorboard=False, show_last_only=False, pause_time=0.1, @@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback): self.border = border self.resolution = resolution self.show_protos = show_protos + self.show = show self.tensorboard = tensorboard self.show_last_only = show_last_only self.pause_time = pause_time @@ -352,10 +355,11 @@ class Vis2DAbstract(pl.Callback): def log_and_display(self, trainer, pl_module): if self.tensorboard: self.add_to_tensorboard(trainer, pl_module) - if not self.block: - plt.pause(self.pause_time) - else: - plt.show(block=True) + if self.show: + if not self.block: + plt.pause(self.pause_time) + else: + plt.show(block=True) def on_train_end(self, trainer, pl_module): plt.show() @@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract): ) self.log_and_display(trainer, pl_module) + + +class VisImgComp(Vis2DAbstract): + def __init__(self, + *args, + random_data=0, + dataformats="CHW", + nrow=2, + **kwargs): + super().__init__(*args, **kwargs) + self.random_data = random_data + self.dataformats = dataformats + self.nrow = nrow + + def on_epoch_end(self, trainer, pl_module): + if not self.precheck(trainer): + return True + + if self.show: + components = pl_module.components + grid = torchvision.utils.make_grid(components, nrow=self.nrow) + plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap) + + self.log_and_display(trainer, pl_module) + + def add_to_tensorboard(self, trainer, pl_module): + tb = pl_module.logger.experiment + + components = pl_module.components + grid = torchvision.utils.make_grid(components, nrow=self.nrow) + tb.add_image( + tag="Components", + img_tensor=grid, + global_step=trainer.current_epoch, + dataformats=self.dataformats, + ) + + if self.random_data: + ind = np.random.choice(len(self.x_train), + size=self.random_data, + replace=False) + data_img = self.x_train[ind] + grid = torchvision.utils.make_grid(data_img, nrow=self.nrow) + tb.add_image(tag="Data", + img_tensor=grid, + global_step=trainer.current_epoch, + dataformats=self.dataformats)