feat: Improve 2D visualization with Voronoi Cells

This commit is contained in:
Alexander Engelsberger 2021-10-15 13:01:01 +02:00
parent 967953442b
commit d1985571b3
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
9 changed files with 109 additions and 238 deletions

View File

@ -38,10 +38,12 @@ if __name__ == "__main__":
) )
# Callbacks # Callbacks
vis = pt.models.VisCBC2D(data=train_ds, vis = pt.models.Visualize2DVoronoiCallback(
title="CBC Iris Example", data=train_ds,
resolution=100, title="CBC Iris Example",
axis_off=True) resolution=100,
axis_off=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(

View File

@ -3,7 +3,7 @@
import argparse import argparse
import prototorch as pt import prototorch as pt
import prototorch.models.expanded import prototorch.models.clcc
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
@ -30,7 +30,7 @@ if __name__ == "__main__":
) )
# Initialize the model # Initialize the model
model = prototorch.models.expanded.GLVQ( model = prototorch.models.GLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
@ -42,7 +42,13 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2) model.example_input_array = torch.zeros(4, 2)
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds) vis = pt.models.Visualize2DVoronoiCallback(
data=train_ds,
resolution=200,
title="Example: GLVQ on Iris",
x_label="sepal length",
y_label="petal length",
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(

View File

View File

@ -7,7 +7,7 @@ from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss from prototorch.core.losses import GLVQLoss
from prototorch.models.expanded.clcc_scheme import CLCCScheme from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.nn.wrappers import LambdaLayer from prototorch.nn.wrappers import LambdaLayer

View File

@ -9,6 +9,7 @@ CLCC is a LVQ scheme containing 4 steps
""" """
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
class CLCCScheme(pl.LightningModule): class CLCCScheme(pl.LightningModule):
@ -36,6 +37,8 @@ class CLCCScheme(pl.LightningModule):
return comparison_tensor return comparison_tensor
def forward(self, batch): def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes? # TODO: manage different datatypes?
components = self.components_layer() components = self.components_layer()
# TODO: => Component Hook # TODO: => Component Hook
@ -43,6 +46,12 @@ class CLCCScheme(pl.LightningModule):
# TODO: => Competition Hook # TODO: => Competition Hook
return self.inference(comparison_tensor, components) return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def loss_forward(self, batch): def loss_forward(self, batch):
# TODO: manage different datatypes? # TODO: manage different datatypes?
components = self.components_layer() components = self.components_layer()

View File

@ -3,12 +3,12 @@ import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.expanded.clcc_glvq import GLVQ, GLVQhparams from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from torch.utils.data import DataLoader, Dataset from prototorch.models.vis import Visualize2DVoronoiCallback
from torchvision import datasets
from torchvision.transforms import Compose, Lambda, ToTensor
plt.gray() # NEW STUFF
# ##############################################################################
# ##############################################################################
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
@ -29,7 +29,8 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds) vis = Visualize2DVoronoiCallback(data=train_ds, resolution=500)
# Train # Train
trainer = pl.Trainer(callbacks=[vis], gpus=1) trainer = pl.Trainer(callbacks=[vis], gpus=1, max_epochs=100)
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -1 +0,0 @@
from .glvq import GLVQ

View File

@ -1,164 +0,0 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.competitions import WTAC, wtac
from prototorch.core.components import Components, LabeledComponents
from prototorch.core.distances import (
euclidean_distance,
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.core.initializers import EyeTransformInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from prototorch.core.pooling import stratified_min_pooling
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
class GLVQ(pl.LightningModule):
def __init__(self, hparams, **kwargs):
super().__init__()
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
# TODO: Manage by an HPARAMS Object
self.hparams.setdefault("lr", 0.01)
self.hparams.setdefault("margin", 0.0)
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
# Default config
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
distance_fn = kwargs.get("distance_fn", euclidean_distance)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if prototypes_initializer is not None:
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
self.distance_layer = LambdaLayer(distance_fn)
self.competition_layer = WTAC()
self.loss = GLVQLoss(
margin=self.hparams.margin,
transfer_fn=self.hparams.transfer_fn,
beta=self.hparams.transfer_beta,
)
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
else:
return optimizer
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
_, plabels = self.proto_layer()
loss = self.loss(out, y, plabels)
return out, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc")
return train_loss
def validation_step(self, batch, batch_idx):
out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc")
return val_loss
def test_step(self, batch, batch_idx):
out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
return test_loss
def test_epoch_end(self, outputs):
test_loss = 0.0
for batch_loss in outputs:
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
# API
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self.compute_distances(x)
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
_, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels)
return y_pred
def predict(self, x):
with torch.no_grad():
distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances)
return y_pred
@property
def prototype_labels(self):
return self.proto_layer.labels.detach().cpu()
@property
def num_classes(self):
return self.proto_layer.num_classes
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
# Python overwrites
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped

View File

@ -5,14 +5,18 @@ import pytorch_lightning as pl
import torch import torch
import torchvision import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.utils.utils import mesh2d from prototorch.utils.utils import generate_mesh, mesh2d
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
COLOR_UNLABELED = 'w'
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
data, data,
title="Prototype Visualization", title=None,
x_label=None,
y_label=None,
cmap="viridis", cmap="viridis",
border=0.1, border=0.1,
resolution=100, resolution=100,
@ -44,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
self.y_train = y self.y_train = y
self.title = title self.title = title
self.x_label = x_label
self.y_label = y_label
self.fig = plt.figure(self.title) self.fig = plt.figure(self.title)
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
@ -56,20 +62,19 @@ class Vis2DAbstract(pl.Callback):
self.pause_time = pause_time self.pause_time = pause_time
self.block = block self.block = block
def precheck(self, trainer): def show_on_current_epoch(self, trainer):
if self.show_last_only: if self.show_last_only and trainer.current_epoch != trainer.max_epochs - 1:
if trainer.current_epoch != trainer.max_epochs - 1: return False
return False
return True return True
def setup_ax(self, xlabel=None, ylabel=None): def setup_ax(self):
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
if xlabel: if self.x_label:
ax.set_xlabel("Data dimension 1") ax.set_xlabel(self.x_label)
if ylabel: if self.x_label:
ax.set_ylabel("Data dimension 2") ax.set_ylabel(self.y_label)
if self.axis_off: if self.axis_off:
ax.axis("off") ax.axis("off")
return ax return ax
@ -116,27 +121,64 @@ class Vis2DAbstract(pl.Callback):
plt.close() plt.close()
class VisGLVQ2D(Vis2DAbstract): class Visualize2DVoronoiCallback(Vis2DAbstract):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.data_min = torch.min(self.x_train, axis=0).values
self.data_max = torch.max(self.x_train, axis=0).values
def current_span(self, proto_values):
proto_min = torch.min(proto_values, axis=0).values
proto_max = torch.max(proto_values, axis=0).values
overall_min = torch.minimum(proto_min, self.data_min)
overall_max = torch.maximum(proto_max, self.data_max)
return overall_min, overall_max
def get_voronoi_diagram(self, min, max, model):
mesh_input, (xx, yy) = generate_mesh(
min,
max,
border=self.border,
resolution=self.resolution,
device=model.device,
)
y_pred = model.predict(mesh_input)
return xx, yy, y_pred.reshape(xx.shape)
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.show_on_current_epoch(trainer):
return True return True
protos = pl_module.prototypes # Extract Prototypes
plabels = pl_module.prototype_labels proto_values = pl_module.prototypes
x_train, y_train = self.x_train, self.y_train if hasattr(pl_module, "prototype_labels"):
ax = self.setup_ax(xlabel="Data dimension 1", proto_labels = pl_module.prototype_labels
ylabel="Data dimension 2") else:
self.plot_data(ax, x_train, y_train) proto_labels = COLOR_UNLABELED
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) # Calculate Voronoi Diagram
mesh_input, xx, yy = mesh2d(x, overall_min, overall_max = self.current_span(proto_values)
self.border, xx, yy, y_pred = self.get_voronoi_diagram(
self.resolution, overall_min,
device=pl_module.device) overall_max,
mesh_input = (mesh_input, None) pl_module,
y_pred = pl_module(mesh_input) )
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx.cpu(), yy.cpu(), y_pred, cmap=self.cmap, alpha=0.35) ax = self.setup_ax()
ax.contourf(
xx.cpu(),
yy.cpu(),
y_pred.cpu(),
cmap=self.cmap,
alpha=0.35,
)
self.plot_data(ax, self.x_train, self.y_train)
self.plot_protos(ax, proto_values, proto_labels)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
@ -147,7 +189,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.map_protos = map_protos self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.show_on_current_epoch(trainer):
return True return True
protos = pl_module.prototypes protos = pl_module.prototypes
@ -185,7 +227,7 @@ class VisGMLVQ2D(Vis2DAbstract):
self.ev_proj = ev_proj self.ev_proj = ev_proj
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.show_on_current_epoch(trainer):
return True return True
protos = pl_module.prototypes protos = pl_module.prototypes
@ -212,40 +254,16 @@ class VisGMLVQ2D(Vis2DAbstract):
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.components_layer._components
y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.show_on_current_epoch(trainer):
return True return True
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy() cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax()
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")
@ -316,7 +334,7 @@ class VisImgComp(Vis2DAbstract):
) )
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.show_on_current_epoch(trainer):
return True return True
if self.show: if self.show: