Use Components instead of Prototypes and refactor old examples

This commit is contained in:
Jensun Ravichandran 2021-04-29 17:05:41 +02:00
parent eeb684b3b6
commit a16bebd0c4
7 changed files with 216 additions and 235 deletions

View File

@ -35,7 +35,7 @@ workon pt
git clone git@github.com:si-cim/prototorch_models.git
cd prototorch_models
git checkout dev
pip install -e .[all] # \[all\] if you are using zsh
pip install -e .[all] # \[all\] if you are using zsh or MacOS
```
To assist in the development process, you may also find it useful to install

View File

@ -1,63 +1,14 @@
"""GLVQ example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
@ -69,24 +20,21 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
lr=0.1,
prototypes_per_class=2,
prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = GLVQ(hparams, data=[x_train, y_train])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=50, callbacks=[vis])
trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -3,63 +3,13 @@
import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.neural_gas import NeuralGas
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module: NeuralGas):
protos = pl_module.proto_layer.prototypes.detach().cpu().numpy()
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
@ -68,7 +18,6 @@ if __name__ == "__main__":
scaler.fit(x_train)
x_train = scaler.transform(x_train)
y_single_class = np.zeros_like(y_train)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
@ -77,20 +26,18 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=1,
prototypes_per_class=30,
prototype_initializer="rand",
lr=0.1,
num_prototypes=30,
lr=0.01,
)
# Initialize the model
model = NeuralGas(hparams, data=[x_train, y_single_class])
model = NeuralGas(hparams)
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
vis = VisNG2D(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(

View File

@ -1,70 +1,15 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import (StratifiedMeanInitializer,
StratifiedSelectionInitializer)
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 0.2, x[:, 0].max() + 0.2
y_min, y_max = x[:, 1].min() - 0.2, x[:, 1].max() + 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -90,23 +35,24 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
prototype_initializer=StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = SiameseGLVQ(hparams,
backbone_module=Backbone,
data=[x_train, y_train])
model = SiameseGLVQ(
hparams,
backbone_module=Backbone,
)
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
vis = VisSiameseGLVQ2D(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])

View File

@ -1,16 +1,17 @@
import os
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string)
class VisWeights(Callback):
class VisWeights(pl.Callback):
"""Abstract weight visualization callback."""
def __init__(
self,
@ -258,3 +259,155 @@ class VisPointProtos(VisWeights):
epoch = trainer.current_epoch
self._display_logs(self.ax, epoch, logs)
self._show_and_save(epoch)
class VisGLVQ2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
class VisSiameseGLVQ2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class VisNG2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)

View File

@ -1,14 +1,16 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel
class GLVQ(pl.LightningModule):
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__()
@ -18,29 +20,18 @@ class GLVQ(pl.LightningModule):
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.proto_layer = LabeledComponents(
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
initializer=self.hparams.prototype_initializer)
self.train_acc = torchmetrics.Accuracy()
@property
def prototypes(self):
return self.proto_layer.prototypes.detach().numpy()
@property
def prototype_labels(self):
return self.proto_layer.prototype_labels.detach().numpy()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
return self.proto_layer.component_labels.detach().numpy()
def forward(self, x):
protos = self.proto_layer.prototypes
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
@ -48,7 +39,7 @@ class GLVQ(pl.LightningModule):
x, y = train_batch
x = x.view(x.size(0), -1)
dis = self(x)
plabels = self.proto_layer.prototype_labels
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
loss = mu.sum(dim=0)
self.log("train_loss", loss)
@ -77,7 +68,7 @@ class GLVQ(pl.LightningModule):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.prototype_labels
plabels = self.proto_layer.component_labels
y_pred = wtac(d, plabels)
return y_pred.numpy()
@ -89,7 +80,7 @@ class ImageGLVQ(GLVQ):
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
self.proto_layer.components.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ):
@ -115,7 +106,7 @@ class SiameseGLVQ(GLVQ):
def forward(self, x):
self.sync_backbones()
protos = self.proto_layer.prototypes
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
@ -126,9 +117,8 @@ class SiameseGLVQ(GLVQ):
def predict_latent(self, x):
# model.eval() # ?!
with torch.no_grad():
protos = self.proto_layer.prototypes
protos, plabels = self.proto_layer()
latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos)
plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels)
return y_pred.numpy()

View File

@ -1,10 +1,13 @@
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel
class EuclideanDistance(torch.nn.Module):
def forward(self, x, y):
@ -34,41 +37,35 @@ class ConnectionTopology(torch.nn.Module):
return f"agelimit: {self.agelimit}"
class NeuralGas(pl.LightningModule):
class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)
self.hparams.setdefault("prototype_initializer", "zeros")
self.hparams.setdefault("prototype_initializer",
cinit.ZerosInitializer(self.hparams.input_dim))
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs,
)
self.proto_layer = Components(
self.hparams.num_prototypes,
initializer=self.hparams.prototype_initializer)
self.distance_layer = EuclideanDistance()
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
num_prototypes=len(self.proto_layer.prototypes),
num_prototypes=self.hparams.num_prototypes,
)
def training_step(self, train_batch, batch_idx):
x, _ = train_batch
protos, _ = self.proto_layer()
x = train_batch[0]
protos = self.proto_layer()
d = self.distance_layer(x, protos)
cost, order = self.energy_layer(d)
self.topology_layer(d)
return cost
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer