diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py new file mode 100644 index 0000000..01c31cb --- /dev/null +++ b/examples/grlvq_iris.py @@ -0,0 +1,62 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from sklearn.datasets import load_iris +from torch.utils.data import DataLoader + +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GRLVQ + +from sklearn.preprocessing import StandardScaler + + +class PrintRelevanceCallback(pl.Callback): + def on_epoch_end(self, trainer, pl_module: GRLVQ): + print(pl_module.relevance_profile) + + +if __name__ == "__main__": + # Dataset + x_train, y_train = load_iris(return_X_y=True) + x_train = x_train[:, [0, 2]] + scaler = StandardScaler() + scaler.fit(x_train) + x_train = scaler.transform(x_train) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, + num_workers=0, + batch_size=50, + shuffle=True) + + # Hyperparameters + hparams = dict( + nclasses=3, + prototypes_per_class=1, + #prototype_initializer=cinit.SMI(torch.Tensor(x_train), + # torch.Tensor(y_train)), + prototype_initializer=cinit.UniformInitializer(2), + input_dim=x_train.shape[1], + lr=0.1, + #transfer_function="sigmoid_beta", + ) + + # Initialize the model + model = GRLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x_train, y_train) + debug = PrintRelevanceCallback() + + # Setup trainer + trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/grlvq_spiral.py b/examples/grlvq_spiral.py new file mode 100644 index 0000000..61d754c --- /dev/null +++ b/examples/grlvq_spiral.py @@ -0,0 +1,57 @@ +"""GMLVQ example using all four dimensions of the Iris dataset.""" + +import pytorch_lightning as pl +import torch +from prototorch.components import initializers as cinit +from prototorch.datasets.abstract import NumpyDataset +from sklearn.datasets import load_iris +from torch.utils.data import DataLoader + +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import GRLVQ + +from sklearn.preprocessing import StandardScaler + +from prototorch.datasets.spiral import make_spiral + + +class PrintRelevanceCallback(pl.Callback): + def on_epoch_end(self, trainer, pl_module: GRLVQ): + print(pl_module.relevance_profile) + + +if __name__ == "__main__": + # Dataset + x_train, y_train = make_spiral(n_samples=1000, noise=0.3) + train_ds = NumpyDataset(x_train, y_train) + + # Dataloaders + train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) + + # Hyperparameters + hparams = dict( + nclasses=2, + prototypes_per_class=20, + prototype_initializer=cinit.SSI(torch.Tensor(x_train), + torch.Tensor(y_train)), + #prototype_initializer=cinit.UniformInitializer(2), + input_dim=x_train.shape[1], + lr=0.1, + #transfer_function="sigmoid_beta", + ) + + # Initialize the model + model = GRLVQ(hparams) + + # Model summary + print(model) + + # Callbacks + vis = VisSiameseGLVQ2D(x_train, y_train) + debug = PrintRelevanceCallback() + + # Setup trainer + trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index ff1da61..868d4c9 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -3,7 +3,7 @@ import torchmetrics from prototorch.components import LabeledComponents from prototorch.functions.activations import get_activation from prototorch.functions.competitions import wtac -from prototorch.functions.distances import (euclidean_distance, +from prototorch.functions.distances import (euclidean_distance, omega_distance, squared_euclidean_distance) from prototorch.functions.losses import glvq_loss @@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel): @property def prototype_labels(self): - return self.proto_layer.component_labels.detach().numpy() + return self.proto_layer.component_labels.detach().cpu() def forward(self, x): protos, _ = self.proto_layer() @@ -148,6 +148,41 @@ class SiameseGLVQ(GLVQ): return y_pred.numpy() +class GRLVQ(GLVQ): + """Generalized Relevance Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.relevances = torch.nn.parameter.Parameter( + torch.ones(self.hparams.input_dim)) + + def forward(self, x): + protos, _ = self.proto_layer() + dis = omega_distance(x, protos, torch.diag(self.relevances)) + return dis + + def backbone(self, x): + return x @ torch.diag(self.relevances) + + @property + def relevance_profile(self): + return self.relevances.detach().cpu() + + def predict_latent(self, x): + """Predict `x` assuming it is already embedded in the latent space. + + Only the prototypes are embedded in the latent space using the + backbone. + + """ + # model.eval() # ?! + with torch.no_grad(): + protos, plabels = self.proto_layer() + latent_protos = protos @ torch.diag(self.relevances) + d = squared_euclidean_distance(x, latent_protos) + y_pred = wtac(d, plabels) + return y_pred.numpy() + + class GMLVQ(GLVQ): """Generalized Matrix Learning Vector Quantization.""" def __init__(self, hparams, **kwargs):