Update Examples to new initializer architecture.

Visualization still borken for some examples.
This commit is contained in:
Alexander Engelsberger 2021-05-06 14:10:09 +02:00
parent d644114090
commit 1c3613019b
15 changed files with 92 additions and 248 deletions

View File

@ -4,13 +4,12 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import make_circles
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisPointProtos
from prototorch.models.cbc import CBC, euclidean_similarity
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
@ -32,7 +31,7 @@ class VisualizationCallback(pl.Callback):
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
protos = pl_module.components
color = pl_module.prototype_labels
else:
protos = pl_module.components
@ -83,8 +82,8 @@ if __name__ == "__main__":
hparams = dict(
input_dim=x_train.shape[1],
nclasses=len(np.unique(y_train)),
prototypes_per_class=5,
prototype_initializer="randn",
num_components=5,
component_initializer=cinit.RandomInitializer(x_train.shape[1]),
lr=0.01,
)
@ -95,31 +94,15 @@ if __name__ == "__main__":
similarity=euclidean_similarity,
)
model = GLVQ(hparams, data=[x_train, y_train])
# Fix the component locations
# model.proto_layer.requires_grad_(False)
# import sys
# sys.exit()
# Model summary
print(model)
# Callbacks
dvis = VisPointProtos(
data=(x_train, y_train),
save=True,
snap=False,
voronoi=True,
resolution=50,
pause_time=0.1,
make_gif=True,
)
dvis = VisualizationCallback(x_train,
y_train,
prototype_model=False,
title="CBC Circle Example")
# Setup trainer
trainer = pl.Trainer(
max_epochs=10,
max_epochs=50,
callbacks=[
dvis,
],

View File

@ -4,30 +4,38 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
from prototorch.models.cbc import CBC, euclidean_similarity
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
# protos = pl_module.prototypes
protos = pl_module.components
# plabels = pl_module.prototype_labels
if self.prototype_model:
protos = pl_module.components
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
@ -37,8 +45,7 @@ class VisualizationCallback(pl.Callback):
ax.scatter(
protos[:, 0],
protos[:, 1],
# c=plabels,
c="k",
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
@ -71,44 +78,33 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
nclasses=len(np.unique(y_train)),
num_components=9,
component_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = CBC(hparams, data=[x_train, y_train])
# Fix the component locations
# model.proto_layer.requires_grad_(False)
# Pure-positive reasonings
ncomps = 3
nclasses = 3
rmat = torch.stack(
[0.9 * torch.eye(ncomps),
torch.zeros(ncomps, nclasses)], dim=0)
# model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat},
# strict=True)
print(model.reasoning_layer.reasoning_probabilities)
# import sys
# sys.exit()
# Model summary
print(model)
model = CBC(
hparams,
data=[x_train, y_train],
similarity=euclidean_similarity,
)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
dvis = VisualizationCallback(x_train,
y_train,
prototype_model=False,
title="CBC Iris Example")
# Setup trainer
trainer = pl.Trainer(
max_epochs=100,
max_epochs=50,
callbacks=[
vis,
dvis,
],
)
# Training loop
trainer.fit(model, train_loader)
trainer.fit(model, train_loader)

View File

@ -4,9 +4,9 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
@ -110,7 +110,7 @@ if __name__ == "__main__":
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(model.proto_layer.prototype_labels):
for i, label in enumerate(model.component_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
model.reasoning_layer.reasoning_probabilities.data = new_reasoning

View File

@ -8,9 +8,9 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
from prototorch.models.glvq import GLVQ
@ -132,11 +132,12 @@ if __name__ == "__main__":
train(glvq_model, x_train, y_train, train_loader, epochs=10)
# Transfer Prototypes
cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict())
cbc_model.component_layer.load_state_dict(
glvq_model.proto_layer.state_dict())
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
cbc_model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(cbc_model.proto_layer.prototype_labels):
for i, label in enumerate(cbc_model.component_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
new_reasoning[1][0][i][1 - int(label)] = 1.0

View File

@ -1,86 +1,16 @@
"""GLVQ example using the Iris dataset."""
import argparse
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
class GLVQIris(GLVQ):
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser],
add_help=False)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-1)
parser.add_argument("--batch_size", type=int, default=150)
parser.add_argument("--input_dim", type=int, default=2)
parser.add_argument("--nclasses", type=int, default=3)
parser.add_argument("--prototypes_per_class", type=int, default=3)
parser.add_argument("--prototype_initializer",
type=str,
default="stratified_mean")
return parser
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# For best-practices when using `argparse` with `pytorch_lightning`, see
# https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html
parser = argparse.ArgumentParser()
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
@ -89,43 +19,23 @@ if __name__ == "__main__":
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Add model specific args
parser = GLVQIris.add_model_specific_args(parser)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc.
parser = pl.Trainer.add_argparse_args(parser)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
parser,
max_epochs=10,
callbacks=[
vis,
], # comment this line out to disable the visualization
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# trainer.tune(model)
# Initialize the model
args = parser.parse_args()
model = GLVQIris(args, data=[x_train, y_train])
model = GLVQ(hparams)
# Model summary
print(model)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop
trainer.fit(model, train_loader)
# Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate)
ckpt = "glvq_iris.ckpt"
trainer.save_checkpoint(ckpt)
# Load the checkpoint
new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt)
print(new_model)
# Continue training
trainer.fit(new_model, train_loader) # TODO See why this fails!

View File

@ -1,40 +0,0 @@
"""GLVQ example using the Iris dataset."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = GLVQ(hparams)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -7,6 +7,7 @@ import argparse
import pytorch_lightning as pl
import torchvision
from prototorch.components import initializers as cinit
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
@ -92,12 +93,12 @@ if __name__ == "__main__":
input_dim=28 * 28,
nclasses=10,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
prototype_initializer=cinit.StratifiedMeanInitializer(x, y),
lr=args.lr,
)
# Initialize the model
model = ImageGLVQ(hparams, data=[x, y])
model = ImageGLVQ(hparams)
# Model summary
print(model)

View File

@ -5,9 +5,10 @@ import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.datasets.spiral import make_spiral
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from torch.utils.data import DataLoader
class StopOnNaN(pl.Callback):

View File

@ -4,11 +4,12 @@ import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)

View File

@ -1,12 +1,12 @@
"""GMLVQ example using the Tecator dataset."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.tecator import Tecator
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Dataset

View File

@ -1,15 +1,14 @@
"""Neural Gas example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)

View File

@ -2,14 +2,16 @@
import pytorch_lightning as pl
import torch
from prototorch.components import (StratifiedMeanInitializer,
StratifiedSelectionInitializer)
from prototorch.components import (
StratifiedMeanInitializer
)
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):

View File

@ -1,10 +1,9 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D
def rescaled_cosine_similarity(x, y):
@ -93,12 +92,8 @@ class CBC(pl.LightningModule):
super().__init__()
self.save_hyperparameters(hparams)
self.margin = margin
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.component_layer = Components(self.hparams.num_components,
self.hparams.component_initializer)
# self.similarity = CosineSimilarity()
self.similarity = similarity
self.backbone = backbone_class()
@ -110,7 +105,7 @@ class CBC(pl.LightningModule):
@property
def components(self):
return self.proto_layer.prototypes.detach().cpu()
return self.component_layer.components.detach().cpu()
@property
def reasonings(self):
@ -126,7 +121,7 @@ class CBC(pl.LightningModule):
def forward(self, x):
self.sync_backbones()
protos, _ = self.proto_layer()
protos = self.component_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
@ -167,4 +162,4 @@ class ImageCBC(CBC):
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
self.component_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import LabeledComponents
@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel
@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
with torch.no_grad():
preds = wtac(dis, plabels)
# `.int()` because FloatTensors are assumed to be class probabilities
self.train_acc(preds.int(), y.int())
# Logging
self.log("train_loss", loss)

View File

@ -1,9 +1,7 @@
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel