Add GMLVQ examples

This commit is contained in:
Jensun Ravichandran 2021-05-04 15:11:16 +02:00
parent a1ac5a70c7
commit f402eea884
4 changed files with 117 additions and 4 deletions

View File

@ -46,13 +46,13 @@ To assist in the development process, you may also find it useful to install
- GLVQ
- Siamese GLVQ
- Neural Gas
- GMLVQ
- Limited-Rank GMLVQ
## Work in Progress
- CBC
- LVQMLN
- GMLVQ
- Limited-Rank GMLVQ
## Planned models
@ -62,3 +62,4 @@ To assist in the development process, you may also find it useful to install
- PLVQ
- SILVQ
- KNN
- LVQ1

47
examples/gmlvq_iris.py Normal file
View File

@ -0,0 +1,47 @@
"""GMLVQ example using all four dimensions of the Iris dataset."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=1,
prototype_initializer=cinit.SMI(torch.Tensor(x_train),
torch.Tensor(y_train)),
input_dim=x_train.shape[1],
latent_dim=2,
lr=0.01,
)
# Initialize the model
model = GMLVQ(hparams)
# Model summary
print(model)
# Callbacks
vis = VisSiameseGLVQ2D(x_train, y_train)
# Namespace hook for the visualization to work
model.backbone = model.omega_layer
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)

47
examples/gmlvq_tecator.py Normal file
View File

@ -0,0 +1,47 @@
"""GMLVQ example using the Tecator dataset."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.tecator import Tecator
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Dataset
train_ds = Tecator(root="./datasets/", train=True)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=32)
# Grab the full dataset to warm-start prototypes
x, y = next(iter(DataLoader(train_ds, batch_size=len(train_ds))))
# Hyperparameters
hparams = dict(
nclasses=2,
prototypes_per_class=2,
prototype_initializer=cinit.SMI(x, y),
input_dim=x.shape[1],
latent_dim=2,
lr=0.01,
)
# Initialize the model
model = GMLVQ(hparams)
# Model summary
print(model)
# Callbacks
vis = VisSiameseGLVQ2D(x, y)
# Namespace hook for the visualization to work
model.backbone = model.omega_layer
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)

View File

@ -94,11 +94,13 @@ class SiameseGLVQ(GLVQ):
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
sync=True,
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
self.backbone_dependent = backbone_module(
**backbone_params).requires_grad_(False)
self.sync = sync
def sync_backbones(self):
master_state = self.backbone.state_dict()
@ -117,7 +119,8 @@ class SiameseGLVQ(GLVQ):
return proto_opt
def forward(self, x):
self.sync_backbones()
if self.sync:
self.sync_backbones()
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
@ -145,7 +148,7 @@ class GMLVQ(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
self.latent_dim,
self.hparams.latent_dim,
bias=False)
def forward(self, x):
@ -155,6 +158,21 @@ class GMLVQ(GLVQ):
dis = squared_euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
# model.eval() # ?!
with torch.no_grad():
protos, plabels = self.proto_layer()
latent_protos = self.omega_layer(protos)
d = squared_euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
class LVQMLN(GLVQ):
"""Learning Vector Quantization Multi-Layer Network.