Add Siamese GLVQ

This commit is contained in:
Jensun Ravichandran 2021-04-27 14:35:17 +02:00
parent 8d57f69c9e
commit 1fb197077c
3 changed files with 170 additions and 12 deletions

View File

@ -43,17 +43,18 @@ To assist in the development process, you may also find it useful to install
## Available models ## Available models
- [X] GLVQ - GLVQ
- [X] Neural Gas - Siamese GLVQ
- Neural Gas
## Work in Progress ## Work in Progress
- [ ] CBC - CBC
## Planned models ## Planned models
- [ ] GMLVQ - GMLVQ
- [ ] Local-Matrix GMLVQ - Local-Matrix GMLVQ
- [ ] Limited-Rank GMLVQ - Limited-Rank GMLVQ
- [ ] GTLVQ - GTLVQ
- [ ] RSLVQ - RSLVQ
- [ ] PLVQ - PLVQ
- [ ] LVQMLN - LVQMLN

View File

@ -0,0 +1,113 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
class VisualizationCallback(pl.Callback):
    """Lightning callback that redraws a 2D latent-space plot every epoch.

    The training data and the prototypes are passed through the model's
    ``backbone``, scattered on a single axes together with the decision
    regions returned by ``pl_module.predict_latent``, and the figure is
    handed to the module's logger.

    NOTE(review): only the first two latent dimensions are plotted and the
    decision regions are evaluated on a two-column grid, so the backbone is
    assumed to produce a two-dimensional embedding.

    Args:
        x_train: Training inputs; converted with ``torch.Tensor`` before
            being fed to the backbone.
        y_train: Training labels, used to color the data points.
        title: Title of the figure/axes and tag of the logged figure.
        cmap: Matplotlib colormap shared by the data points, the prototypes
            and the decision regions.
    """
    def __init__(self,
                 x_train,
                 y_train,
                 title="Prototype Visualization",
                 cmap="viridis"):
        super().__init__()
        self.x_train = x_train
        self.y_train = y_train
        self.title = title
        self.fig = plt.figure(self.title)
        self.cmap = cmap

    def on_epoch_end(self, trainer, pl_module):
        """Redraw the latent-space plot and log it for the current epoch.

        Args:
            trainer: The running trainer; only ``current_epoch`` is read.
            pl_module: The model being trained. It must expose
                ``prototypes``, ``prototype_labels``, ``backbone``,
                ``predict_latent`` and ``logger``.
        """
        plabels = pl_module.prototype_labels
        y_train = self.y_train

        # Embed data and prototypes with the trainable backbone; detach so
        # the results can be handed to numpy/matplotlib.
        x_train = pl_module.backbone(torch.Tensor(self.x_train)).detach()
        protos = pl_module.backbone(
            torch.Tensor(pl_module.prototypes)).detach()

        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.axis("off")
        # Use the same colormap as the prototypes and the decision regions
        # so that a class keeps one color across all three layers.
        ax.scatter(x_train[:, 0],
                   x_train[:, 1],
                   c=y_train,
                   cmap=self.cmap,
                   edgecolor="k")
        ax.scatter(
            protos[:, 0],
            protos[:, 1],
            c=plabels,
            cmap=self.cmap,
            edgecolor="k",
            marker="D",
            s=50,
        )

        # Bounding box of everything that was drawn, plus a small margin.
        margin = 0.2
        step = 1 / 50
        x = np.vstack((x_train, protos))
        x_min, x_max = x[:, 0].min() - margin, x[:, 0].max() + margin
        y_min, y_max = x[:, 1].min() - margin, x[:, 1].max() + margin

        # Classify a dense grid of latent points to draw decision regions.
        xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                             np.arange(y_min, y_max, step))
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
        y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
        y_pred = y_pred.reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
        ax.set_xlim(left=x_min, right=x_max)
        ax.set_ylim(bottom=y_min, top=y_max)

        # Presumably a TensorBoard SummaryWriter (it provides add_figure);
        # close=False keeps the figure alive for the next epoch's redraw.
        tb = pl_module.logger.experiment
        tb.add_figure(tag=f"{self.title}",
                      figure=self.fig,
                      global_step=trainer.current_epoch,
                      close=False)
        # Give the GUI event loop a moment to refresh the window.
        plt.pause(0.1)
class Backbone(torch.nn.Module):
    """Small fully connected network mapping inputs to a latent space.

    Two linear layers, each followed by a ReLU, so every latent coordinate
    is non-negative.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        latent_size: Dimensionality of the produced embedding.
    """
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
        super().__init__()
        self.input_size, self.hidden_size, self.latent_size = (
            input_size,
            hidden_size,
            latent_size,
        )
        # Layers are created in this order so that parameter registration
        # (and therefore random initialisation) stays the same.
        self.dense1 = torch.nn.Linear(input_size, hidden_size)
        self.dense2 = torch.nn.Linear(hidden_size, latent_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the latent representation of ``x``."""
        hidden = self.relu(self.dense1(x))
        latent = self.relu(self.dense2(hidden))
        return latent
if __name__ == "__main__":
    # Load all four Iris features together with their class labels.
    x_train, y_train = load_iris(return_X_y=True)
    train_ds = NumpyDataset(x_train, y_train)

    # NOTE(review): batch_size=150 presumably covers the whole Iris set in
    # a single batch -- confirm if the dataset is swapped.
    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)

    # Model hyperparameters.
    hparams = {
        "input_dim": x_train.shape[1],
        "nclasses": 3,
        "prototypes_per_class": 1,
        "prototype_initializer": "stratified_mean",
        "lr": 0.01,
    }

    # Build the Siamese model; the backbone class is instantiated by the
    # model itself.
    model = SiameseGLVQ(
        hparams,
        backbone_module=Backbone,
        data=[x_train, y_train],
    )
    print(model)

    # Live plot of the latent space, refreshed by the trainer.
    visualizer = VisualizationCallback(x_train, y_train)
    trainer = pl.Trainer(max_epochs=100, callbacks=[visualizer])

    # Run the optimisation.
    trainer.fit(model, train_loader)

View File

@ -68,8 +68,8 @@ class GLVQ(pl.LightningModule):
# self.log("train_acc_epoch", self.train_acc.compute()) # self.log("train_acc_epoch", self.train_acc.compute())
def predict(self, x): def predict(self, x):
with torch.no_grad():
# model.eval() # ?! # model.eval() # ?!
with torch.no_grad():
d = self(x) d = self(x)
plabels = self.proto_layer.prototype_labels plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels) y_pred = wtac(d, plabels)
@ -77,8 +77,52 @@ class GLVQ(pl.LightningModule):
class ImageGLVQ(GLVQ):
    """GLVQ for training on image data.

    GLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        """Clamp the prototypes in place to [0, 1] after a training batch.

        None of the hook arguments are used.
        """
        prototypes = self.proto_layer.prototypes
        prototypes.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. The weights in the
    transformation pipeline are only learned from the inputs.

    Two instances of the backbone are kept: ``backbone`` is trainable and
    embeds the inputs, while ``backbone_dependent`` has gradients disabled,
    embeds the prototypes and receives a copy of the trainable weights
    whenever ``sync_backbones`` is called.

    Args:
        hparams: Hyperparameters forwarded to the parent constructor.
        backbone_module: Callable returning a ``torch.nn.Module``; called
            twice with ``**backbone_params``.
        backbone_params: Keyword arguments for ``backbone_module``. ``None``
            (the default) means no arguments.
        **kwargs: Forwarded to the parent constructor.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params=None,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # Avoid a shared mutable default argument.
        if backbone_params is None:
            backbone_params = {}
        self.backbone = backbone_module(**backbone_params)
        self.backbone_dependent = backbone_module(
            **backbone_params).requires_grad_(False)
        # Both copies are initialised independently; make them agree from
        # the start so prediction before any forward pass is consistent.
        self.sync_backbones()

    def sync_backbones(self):
        """Copy the trainable backbone's state into the dependent copy."""
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def forward(self, x):
        """Return the distances between embedded inputs and prototypes.

        The dependent backbone is refreshed first so that inputs and
        prototypes are embedded with identical weights.
        """
        self.sync_backbones()
        protos = self.proto_layer.prototypes
        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)
        dis = euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict labels for points that are already in the latent space.

        Args:
            x: Tensor of latent points; it is NOT passed through the
                backbone.

        Returns:
            The predicted labels converted with ``.numpy()``.
        """
        # The optimizer may have updated ``backbone`` since the last forward
        # pass; without this sync the prototypes would be embedded with
        # stale weights.
        self.sync_backbones()
        # NOTE(review): the module is not switched to eval mode here; this
        # matters only if the backbone contains dropout/batch-norm layers.
        with torch.no_grad():
            protos = self.proto_layer.prototypes
            latent_protos = self.backbone_dependent(protos)
            d = euclidean_distance(x, latent_protos)
            plabels = self.proto_layer.prototype_labels
            y_pred = wtac(d, plabels)
        # ``.cpu()`` is a no-op for CPU tensors and is required before
        # ``.numpy()`` when the result lives on an accelerator.
        return y_pred.cpu().numpy()