Use Components instead of Prototypes and refactor old examples

Jensun Ravichandran
2021-04-29 17:05:41 +02:00
parent eeb684b3b6
commit a16bebd0c4
7 changed files with 216 additions and 235 deletions

View File

@@ -1,63 +1,14 @@
"""GLVQ example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
@@ -69,24 +20,21 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
lr=0.1,
prototypes_per_class=2,
prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = GLVQ(hparams, data=[x_train, y_train])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=50, callbacks=[vis])
trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop
trainer.fit(model, train_loader)
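
For readers skimming this diff, the refactored GLVQ example condenses roughly to the self-contained sketch below. The Components-based initializer, the VisGLVQ2D callback, the hparams, and the trainer call are taken from the lines above; the two-feature slice and the batch size are assumptions, since the data-loading lines fall outside the hunks shown here.

"""Sketch: GLVQ on Iris with a Components-based prototype initializer."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Dataset (the two-feature slice is an assumption for 2D visualization)
    x_train, y_train = load_iris(return_X_y=True)
    x_train = x_train[:, [0, 2]]
    train_ds = NumpyDataset(x_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=150)  # illustrative batch size

    # Hyperparameters: the initializer is now a Components initializer object,
    # not the old "stratified_mean" string.
    hparams = dict(
        input_dim=x_train.shape[1],
        nclasses=3,
        prototypes_per_class=2,
        prototype_initializer=cinit.StratifiedMeanInitializer(
            torch.Tensor(x_train), torch.Tensor(y_train)),
        lr=0.01,
    )

    # Initialize the model and train with the packaged visualization callback
    model = GLVQ(hparams, data=[x_train, y_train])
    trainer = pl.Trainer(
        max_epochs=50,
        callbacks=[VisGLVQ2D(x_train, y_train)],
    )
    trainer.fit(model, train_loader)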

View File

@@ -3,63 +3,13 @@
import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module: NeuralGas):
protos = pl_module.proto_layer.prototypes.detach().cpu().numpy()
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
@@ -68,7 +18,6 @@ if __name__ == "__main__":
scaler.fit(x_train)
x_train = scaler.transform(x_train)
y_single_class = np.zeros_like(y_train)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
@@ -77,20 +26,18 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=1,
prototypes_per_class=30,
prototype_initializer="rand",
lr=0.1,
num_prototypes=30,
lr=0.01,
)
# Initialize the model
model = NeuralGas(hparams, data=[x_train, y_single_class])
model = NeuralGas(hparams)
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
vis = VisNG2D(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(
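
The Neural Gas example now reads as an unsupervised script: num_prototypes replaces the nclasses and prototypes_per_class pair, the model constructor no longer takes the label array, and the packaged VisNG2D callback replaces the hand-written one. A minimal sketch under those assumptions follows; the max_epochs value and the batch size are illustrative, since the trainer call is cut off above.

"""Sketch: unsupervised Neural Gas on standardized Iris features."""
import pytorch_lightning as pl
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Dataset (the two-feature slice is an assumption for 2D visualization)
    x_train, y_train = load_iris(return_X_y=True)
    x_train = x_train[:, [0, 2]]
    scaler = StandardScaler()
    scaler.fit(x_train)
    x_train = scaler.transform(x_train)
    train_ds = NumpyDataset(x_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=150)  # illustrative batch size

    # Hyperparameters: a flat prototype count, no class information
    hparams = dict(
        input_dim=x_train.shape[1],
        num_prototypes=30,
        lr=0.01,
    )

    # Initialize and train; the model itself receives no label information
    model = NeuralGas(hparams)
    vis = VisNG2D(x_train, y_train)
    trainer = pl.Trainer(max_epochs=100, callbacks=[vis])  # illustrative epoch count
    trainer.fit(model, train_loader)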

View File

@@ -1,70 +1,15 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import (StratifiedMeanInitializer,
StratifiedSelectionInitializer)
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 0.2, x[:, 0].max() + 0.2
y_min, y_max = x[:, 1].min() - 0.2, x[:, 1].max() + 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@@ -90,23 +35,24 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
prototype_initializer=StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = SiameseGLVQ(hparams,
backbone_module=Backbone,
data=[x_train, y_train])
model = SiameseGLVQ(
hparams,
backbone_module=Backbone,
)
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
vis = VisSiameseGLVQ2D(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
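
Likewise, the Siamese GLVQ example now passes a StratifiedMeanInitializer object and drops the data argument from the model constructor. The sketch below pieces this together; only the Backbone signature appears in the diff, so its body, the batch size, and the dataloader setup are assumptions.

"""Sketch: Siamese GLVQ on all four Iris dimensions with a learned 2D latent space."""
import pytorch_lightning as pl
import torch
from prototorch.components import StratifiedMeanInitializer
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader

class Backbone(torch.nn.Module):
    """Assumed body: two dense layers mapping 4D inputs to a 2D latent space."""
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
        super().__init__()
        self.dense1 = torch.nn.Linear(input_size, hidden_size)
        self.dense2 = torch.nn.Linear(hidden_size, latent_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.dense2(self.relu(self.dense1(x))))

if __name__ == "__main__":
    # Dataset: all four Iris features, as in the docstring above
    x_train, y_train = load_iris(return_X_y=True)
    train_ds = NumpyDataset(x_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=150)  # illustrative batch size

    # Hyperparameters with the Components-based initializer
    hparams = dict(
        input_dim=x_train.shape[1],
        nclasses=3,
        prototypes_per_class=1,
        prototype_initializer=StratifiedMeanInitializer(
            torch.Tensor(x_train), torch.Tensor(y_train)),
        lr=0.01,
    )

    # The model owns the backbone; no data argument is needed anymore
    model = SiameseGLVQ(
        hparams,
        backbone_module=Backbone,
    )
    vis = VisSiameseGLVQ2D(x_train, y_train)
    trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
    trainer.fit(model, train_loader)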