Update Examples to new initializer architecture.

Visualization still borken for some examples.
This commit is contained in:
Alexander Engelsberger 2021-05-06 14:10:09 +02:00
parent d644114090
commit 1c3613019b
15 changed files with 92 additions and 248 deletions

View File

@ -4,13 +4,12 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import make_circles from sklearn.datasets import make_circles
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisPointProtos
from prototorch.models.cbc import CBC, euclidean_similarity from prototorch.models.cbc import CBC, euclidean_similarity
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
@ -32,7 +31,7 @@ class VisualizationCallback(pl.Callback):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if self.prototype_model: if self.prototype_model:
protos = pl_module.prototypes protos = pl_module.components
color = pl_module.prototype_labels color = pl_module.prototype_labels
else: else:
protos = pl_module.components protos = pl_module.components
@ -83,8 +82,8 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
input_dim=x_train.shape[1], input_dim=x_train.shape[1],
nclasses=len(np.unique(y_train)), nclasses=len(np.unique(y_train)),
prototypes_per_class=5, num_components=5,
prototype_initializer="randn", component_initializer=cinit.RandomInitializer(x_train.shape[1]),
lr=0.01, lr=0.01,
) )
@ -95,31 +94,15 @@ if __name__ == "__main__":
similarity=euclidean_similarity, similarity=euclidean_similarity,
) )
model = GLVQ(hparams, data=[x_train, y_train])
# Fix the component locations
# model.proto_layer.requires_grad_(False)
# import sys
# sys.exit()
# Model summary
print(model)
# Callbacks # Callbacks
dvis = VisPointProtos( dvis = VisualizationCallback(x_train,
data=(x_train, y_train), y_train,
save=True, prototype_model=False,
snap=False, title="CBC Circle Example")
voronoi=True,
resolution=50,
pause_time=0.1,
make_gif=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=10, max_epochs=50,
callbacks=[ callbacks=[
dvis, dvis,
], ],

View File

@ -4,30 +4,38 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset from prototorch.models.cbc import CBC, euclidean_similarity
from prototorch.models.cbc import CBC
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
def __init__(self, def __init__(
x_train, self,
y_train, x_train,
title="Prototype Visualization", y_train,
cmap="viridis"): prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__() super().__init__()
self.x_train = x_train self.x_train = x_train
self.y_train = y_train self.y_train = y_train
self.title = title self.title = title
self.fig = plt.figure(self.title) self.fig = plt.figure(self.title)
self.cmap = cmap self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
# protos = pl_module.prototypes if self.prototype_model:
protos = pl_module.components protos = pl_module.components
# plabels = pl_module.prototype_labels color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
@ -37,8 +45,7 @@ class VisualizationCallback(pl.Callback):
ax.scatter( ax.scatter(
protos[:, 0], protos[:, 0],
protos[:, 1], protos[:, 1],
# c=plabels, c=color,
c="k",
cmap=self.cmap, cmap=self.cmap,
edgecolor="k", edgecolor="k",
marker="D", marker="D",
@ -71,44 +78,33 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
input_dim=x_train.shape[1], input_dim=x_train.shape[1],
nclasses=3, nclasses=len(np.unique(y_train)),
prototypes_per_class=3, num_components=9,
prototype_initializer="stratified_mean", component_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = CBC(hparams, data=[x_train, y_train]) model = CBC(
hparams,
# Fix the component locations data=[x_train, y_train],
# model.proto_layer.requires_grad_(False) similarity=euclidean_similarity,
)
# Pure-positive reasonings
ncomps = 3
nclasses = 3
rmat = torch.stack(
[0.9 * torch.eye(ncomps),
torch.zeros(ncomps, nclasses)], dim=0)
# model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat},
# strict=True)
print(model.reasoning_layer.reasoning_probabilities)
# import sys
# sys.exit()
# Model summary
print(model)
# Callbacks # Callbacks
vis = VisualizationCallback(x_train, y_train) dvis = VisualizationCallback(x_train,
y_train,
prototype_model=False,
title="CBC Iris Example")
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=100, max_epochs=50,
callbacks=[ callbacks=[
vis, dvis,
], ],
) )
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -4,9 +4,9 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC from prototorch.models.cbc import CBC
@ -110,7 +110,7 @@ if __name__ == "__main__":
# Pure-positive reasonings # Pure-positive reasonings
new_reasoning = torch.zeros_like( new_reasoning = torch.zeros_like(
model.reasoning_layer.reasoning_probabilities) model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(model.proto_layer.prototype_labels): for i, label in enumerate(model.component_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0 new_reasoning[0][0][i][int(label)] = 1.0
model.reasoning_layer.reasoning_probabilities.data = new_reasoning model.reasoning_layer.reasoning_probabilities.data = new_reasoning

View File

@ -8,9 +8,9 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC from prototorch.models.cbc import CBC
from prototorch.models.glvq import GLVQ from prototorch.models.glvq import GLVQ
@ -132,11 +132,12 @@ if __name__ == "__main__":
train(glvq_model, x_train, y_train, train_loader, epochs=10) train(glvq_model, x_train, y_train, train_loader, epochs=10)
# Transfer Prototypes # Transfer Prototypes
cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict()) cbc_model.component_layer.load_state_dict(
glvq_model.proto_layer.state_dict())
# Pure-positive reasonings # Pure-positive reasonings
new_reasoning = torch.zeros_like( new_reasoning = torch.zeros_like(
cbc_model.reasoning_layer.reasoning_probabilities) cbc_model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(cbc_model.proto_layer.prototype_labels): for i, label in enumerate(cbc_model.component_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0 new_reasoning[0][0][i][int(label)] = 1.0
new_reasoning[1][0][i][1 - int(label)] = 1.0 new_reasoning[1][0][i][1 - int(label)] = 1.0

View File

@ -1,86 +1,16 @@
"""GLVQ example using the Iris dataset.""" """GLVQ example using the Iris dataset."""
import argparse
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ from prototorch.models.glvq import GLVQ
class GLVQIris(GLVQ):
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser],
add_help=False)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-1)
parser.add_argument("--batch_size", type=int, default=150)
parser.add_argument("--input_dim", type=int, default=2)
parser.add_argument("--nclasses", type=int, default=3)
parser.add_argument("--prototypes_per_class", type=int, default=3)
parser.add_argument("--prototype_initializer",
type=str,
default="stratified_mean")
return parser
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__": if __name__ == "__main__":
# For best-practices when using `argparse` with `pytorch_lightning`, see
# https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html
parser = argparse.ArgumentParser()
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]] x_train = x_train[:, [0, 2]]
@ -89,43 +19,23 @@ if __name__ == "__main__":
# Dataloaders # Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Add model specific args # Hyperparameters
parser = GLVQIris.add_model_specific_args(parser) hparams = dict(
nclasses=3,
# Callbacks prototypes_per_class=2,
vis = VisualizationCallback(x_train, y_train) prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
# Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc. lr=0.01,
parser = pl.Trainer.add_argparse_args(parser)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
parser,
max_epochs=10,
callbacks=[
vis,
], # comment this line out to disable the visualization
) )
# trainer.tune(model)
# Initialize the model # Initialize the model
args = parser.parse_args() model = GLVQ(hparams)
model = GLVQIris(args, data=[x_train, y_train])
# Model summary # Setup trainer
print(model) trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate)
ckpt = "glvq_iris.ckpt"
trainer.save_checkpoint(ckpt)
# Load the checkpoint
new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt)
print(new_model)
# Continue training
trainer.fit(new_model, train_loader) # TODO See why this fails!

View File

@ -1,40 +0,0 @@
"""GLVQ example using the Iris dataset."""
import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=cinit.StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
)
# Initialize the model
model = GLVQ(hparams)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -7,6 +7,7 @@ import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torchvision import torchvision
from prototorch.components import initializers as cinit
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
@ -92,12 +93,12 @@ if __name__ == "__main__":
input_dim=28 * 28, input_dim=28 * 28,
nclasses=10, nclasses=10,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer=cinit.StratifiedMeanInitializer(x, y),
lr=args.lr, lr=args.lr,
) )
# Initialize the model # Initialize the model
model = ImageGLVQ(hparams, data=[x, y]) model = ImageGLVQ(hparams)
# Model summary # Model summary
print(model) print(model)

View File

@ -5,9 +5,10 @@ import torch
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
from prototorch.datasets.spiral import make_spiral from prototorch.datasets.spiral import make_spiral
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisGLVQ2D from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ from prototorch.models.glvq import GLVQ
from torch.utils.data import DataLoader
class StopOnNaN(pl.Callback): class StopOnNaN(pl.Callback):

View File

@ -4,11 +4,12 @@ import pytorch_lightning as pl
import torch import torch
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)

View File

@ -1,12 +1,12 @@
"""GMLVQ example using the Tecator dataset.""" """GMLVQ example using the Tecator dataset."""
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import GMLVQ from prototorch.models.glvq import GMLVQ
from torch.utils.data import DataLoader
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset

View File

@ -1,15 +1,14 @@
"""Neural Gas example using the Iris dataset.""" """Neural Gas example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)

View File

@ -2,14 +2,16 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.components import (StratifiedMeanInitializer, from prototorch.components import (
StratifiedSelectionInitializer) StratifiedMeanInitializer
)
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):

View File

@ -1,10 +1,9 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D
def rescaled_cosine_similarity(x, y): def rescaled_cosine_similarity(x, y):
@ -93,12 +92,8 @@ class CBC(pl.LightningModule):
super().__init__() super().__init__()
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
self.margin = margin self.margin = margin
self.proto_layer = Prototypes1D( self.component_layer = Components(self.hparams.num_components,
input_dim=self.hparams.input_dim, self.hparams.component_initializer)
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
# self.similarity = CosineSimilarity() # self.similarity = CosineSimilarity()
self.similarity = similarity self.similarity = similarity
self.backbone = backbone_class() self.backbone = backbone_class()
@ -110,7 +105,7 @@ class CBC(pl.LightningModule):
@property @property
def components(self): def components(self):
return self.proto_layer.prototypes.detach().cpu() return self.component_layer.components.detach().cpu()
@property @property
def reasonings(self): def reasonings(self):
@ -126,7 +121,7 @@ class CBC(pl.LightningModule):
def forward(self, x): def forward(self, x):
self.sync_backbones() self.sync_backbones()
protos, _ = self.proto_layer() protos = self.component_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos) latent_protos = self.backbone_dependent(protos)
@ -167,4 +162,4 @@ class ImageCBC(CBC):
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0) self.component_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance) squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel from .abstract import AbstractPrototypeModel
@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
with torch.no_grad(): with torch.no_grad():
preds = wtac(dis, plabels) preds = wtac(dis, plabels)
# `.int()` because FloatTensors are assumed to be class probabilities # `.int()` because FloatTensors are assumed to be class probabilities
self.train_acc(preds.int(), y.int())
# Logging # Logging
self.log("train_loss", loss) self.log("train_loss", loss)

View File

@ -1,9 +1,7 @@
import pytorch_lightning as pl
import torch import torch
from prototorch.components import Components from prototorch.components import Components
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel from .abstract import AbstractPrototypeModel