Add GRLVQ with examples.
This commit is contained in:
		
							
								
								
									
										62
									
								
								examples/grlvq_iris.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								examples/grlvq_iris.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | """GMLVQ example using all four dimensions of the Iris dataset.""" | ||||||
|  |  | ||||||
|  | import pytorch_lightning as pl | ||||||
|  | import torch | ||||||
|  | from prototorch.components import initializers as cinit | ||||||
|  | from prototorch.datasets.abstract import NumpyDataset | ||||||
|  | from sklearn.datasets import load_iris | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  |  | ||||||
|  | from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D | ||||||
|  | from prototorch.models.glvq import GRLVQ | ||||||
|  |  | ||||||
|  | from sklearn.preprocessing import StandardScaler | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PrintRelevanceCallback(pl.Callback): | ||||||
|  |     def on_epoch_end(self, trainer, pl_module: GRLVQ): | ||||||
|  |         print(pl_module.relevance_profile) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # Dataset | ||||||
|  |     x_train, y_train = load_iris(return_X_y=True) | ||||||
|  |     x_train = x_train[:, [0, 2]] | ||||||
|  |     scaler = StandardScaler() | ||||||
|  |     scaler.fit(x_train) | ||||||
|  |     x_train = scaler.transform(x_train) | ||||||
|  |     train_ds = NumpyDataset(x_train, y_train) | ||||||
|  |  | ||||||
|  |     # Dataloaders | ||||||
|  |     train_loader = DataLoader(train_ds, | ||||||
|  |                               num_workers=0, | ||||||
|  |                               batch_size=50, | ||||||
|  |                               shuffle=True) | ||||||
|  |  | ||||||
|  |     # Hyperparameters | ||||||
|  |     hparams = dict( | ||||||
|  |         nclasses=3, | ||||||
|  |         prototypes_per_class=1, | ||||||
|  |         #prototype_initializer=cinit.SMI(torch.Tensor(x_train), | ||||||
|  |         #                                torch.Tensor(y_train)), | ||||||
|  |         prototype_initializer=cinit.UniformInitializer(2), | ||||||
|  |         input_dim=x_train.shape[1], | ||||||
|  |         lr=0.1, | ||||||
|  |         #transfer_function="sigmoid_beta", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Initialize the model | ||||||
|  |     model = GRLVQ(hparams) | ||||||
|  |  | ||||||
|  |     # Model summary | ||||||
|  |     print(model) | ||||||
|  |  | ||||||
|  |     # Callbacks | ||||||
|  |     vis = VisSiameseGLVQ2D(x_train, y_train) | ||||||
|  |     debug = PrintRelevanceCallback() | ||||||
|  |  | ||||||
|  |     # Setup trainer | ||||||
|  |     trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) | ||||||
|  |  | ||||||
|  |     # Training loop | ||||||
|  |     trainer.fit(model, train_loader) | ||||||
							
								
								
									
										57
									
								
								examples/grlvq_spiral.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								examples/grlvq_spiral.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | |||||||
|  | """GMLVQ example using all four dimensions of the Iris dataset.""" | ||||||
|  |  | ||||||
|  | import pytorch_lightning as pl | ||||||
|  | import torch | ||||||
|  | from prototorch.components import initializers as cinit | ||||||
|  | from prototorch.datasets.abstract import NumpyDataset | ||||||
|  | from sklearn.datasets import load_iris | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  |  | ||||||
|  | from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D | ||||||
|  | from prototorch.models.glvq import GRLVQ | ||||||
|  |  | ||||||
|  | from sklearn.preprocessing import StandardScaler | ||||||
|  |  | ||||||
|  | from prototorch.datasets.spiral import make_spiral | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PrintRelevanceCallback(pl.Callback): | ||||||
|  |     def on_epoch_end(self, trainer, pl_module: GRLVQ): | ||||||
|  |         print(pl_module.relevance_profile) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # Dataset | ||||||
|  |     x_train, y_train = make_spiral(n_samples=1000, noise=0.3) | ||||||
|  |     train_ds = NumpyDataset(x_train, y_train) | ||||||
|  |  | ||||||
|  |     # Dataloaders | ||||||
|  |     train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) | ||||||
|  |  | ||||||
|  |     # Hyperparameters | ||||||
|  |     hparams = dict( | ||||||
|  |         nclasses=2, | ||||||
|  |         prototypes_per_class=20, | ||||||
|  |         prototype_initializer=cinit.SSI(torch.Tensor(x_train), | ||||||
|  |                                         torch.Tensor(y_train)), | ||||||
|  |         #prototype_initializer=cinit.UniformInitializer(2), | ||||||
|  |         input_dim=x_train.shape[1], | ||||||
|  |         lr=0.1, | ||||||
|  |         #transfer_function="sigmoid_beta", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Initialize the model | ||||||
|  |     model = GRLVQ(hparams) | ||||||
|  |  | ||||||
|  |     # Model summary | ||||||
|  |     print(model) | ||||||
|  |  | ||||||
|  |     # Callbacks | ||||||
|  |     vis = VisSiameseGLVQ2D(x_train, y_train) | ||||||
|  |     debug = PrintRelevanceCallback() | ||||||
|  |  | ||||||
|  |     # Setup trainer | ||||||
|  |     trainer = pl.Trainer(max_epochs=200, callbacks=[vis, debug]) | ||||||
|  |  | ||||||
|  |     # Training loop | ||||||
|  |     trainer.fit(model, train_loader) | ||||||
| @@ -3,7 +3,7 @@ import torchmetrics | |||||||
| from prototorch.components import LabeledComponents | from prototorch.components import LabeledComponents | ||||||
| from prototorch.functions.activations import get_activation | from prototorch.functions.activations import get_activation | ||||||
| from prototorch.functions.competitions import wtac | from prototorch.functions.competitions import wtac | ||||||
| from prototorch.functions.distances import (euclidean_distance, | from prototorch.functions.distances import (euclidean_distance, omega_distance, | ||||||
|                                             squared_euclidean_distance) |                                             squared_euclidean_distance) | ||||||
| from prototorch.functions.losses import glvq_loss | from prototorch.functions.losses import glvq_loss | ||||||
|  |  | ||||||
| @@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def prototype_labels(self): |     def prototype_labels(self): | ||||||
|         return self.proto_layer.component_labels.detach().numpy() |         return self.proto_layer.component_labels.detach().cpu() | ||||||
|  |  | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         protos, _ = self.proto_layer() |         protos, _ = self.proto_layer() | ||||||
| @@ -148,6 +148,41 @@ class SiameseGLVQ(GLVQ): | |||||||
|         return y_pred.numpy() |         return y_pred.numpy() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class GRLVQ(GLVQ): | ||||||
|  |     """Generalized Relevance Learning Vector Quantization.""" | ||||||
|  |     def __init__(self, hparams, **kwargs): | ||||||
|  |         super().__init__(hparams, **kwargs) | ||||||
|  |         self.relevances = torch.nn.parameter.Parameter( | ||||||
|  |             torch.ones(self.hparams.input_dim)) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         protos, _ = self.proto_layer() | ||||||
|  |         dis = omega_distance(x, protos, torch.diag(self.relevances)) | ||||||
|  |         return dis | ||||||
|  |  | ||||||
|  |     def backbone(self, x): | ||||||
|  |         return x @ torch.diag(self.relevances) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def relevance_profile(self): | ||||||
|  |         return self.relevances.detach().cpu() | ||||||
|  |  | ||||||
|  |     def predict_latent(self, x): | ||||||
|  |         """Predict `x` assuming it is already embedded in the latent space. | ||||||
|  |  | ||||||
|  |         Only the prototypes are embedded in the latent space using the | ||||||
|  |         backbone. | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |         # model.eval()  # ?! | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             protos, plabels = self.proto_layer() | ||||||
|  |             latent_protos = protos @ torch.diag(self.relevances) | ||||||
|  |             d = squared_euclidean_distance(x, latent_protos) | ||||||
|  |             y_pred = wtac(d, plabels) | ||||||
|  |         return y_pred.numpy() | ||||||
|  |  | ||||||
|  |  | ||||||
| class GMLVQ(GLVQ): | class GMLVQ(GLVQ): | ||||||
|     """Generalized Matrix Learning Vector Quantization.""" |     """Generalized Matrix Learning Vector Quantization.""" | ||||||
|     def __init__(self, hparams, **kwargs): |     def __init__(self, hparams, **kwargs): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user