Stop passing component initializers as hparams

Passing the component initializer as an hparam slows down the script
considerably. The API has now been changed to pass it as a kwarg to the models
instead.

The example scripts have also been updated to reflect the new changes.

ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also been
added.
This commit is contained in:
Jensun Ravichandran
2021-05-12 16:36:22 +02:00
parent 1498c4bde5
commit ca39aa00d5
11 changed files with 172 additions and 21 deletions

View File

@@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
SiameseGLVQ)
ImageGMLVQ, SiameseGLVQ)
from .knn import KNN
from .neural_gas import NeuralGas
from .vis import *

View File

@@ -8,6 +8,11 @@ class AbstractPrototypeModel(pl.LightningModule):
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
@@ -19,3 +24,8 @@ class AbstractPrototypeModel(pl.LightningModule):
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class PrototypeImageModel(pl.LightningModule):
    """LightningModule base that keeps prototype values inside [0, 1].

    After every training batch the tensor behind
    ``self.proto_layer.components`` is clamped in place to the closed
    interval [0, 1].

    NOTE(review): ``proto_layer`` is not created here; it is presumably set
    up by the concrete model class this one is combined with -- confirm.
    """

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        """Clamp the prototypes in place; all hook arguments are unused."""
        prototypes = self.proto_layer.components
        # Operate on ``.data`` so the update is applied directly to the
        # stored values.
        prototypes.data.clamp_(min=0.0, max=1.0)

View File

@@ -5,9 +5,18 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules.mappings import OmegaMapping
from .abstract import AbstractPrototypeModel
from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel):
@@ -18,6 +27,7 @@ class GLVQ(AbstractPrototypeModel):
self.save_hyperparameters(hparams)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
prototype_initializer = kwargs.get("prototype_initializer", None)
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
@@ -26,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
initializer=self.hparams.prototype_initializer)
initializer=prototype_initializer)
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
@@ -44,7 +54,6 @@ class GLVQ(AbstractPrototypeModel):
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
x, y = train_batch
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = self.loss(dis, y, prototype_labels=plabels)
@@ -95,15 +104,14 @@ class LVQ21(GLVQ):
self.optimizer = torch.optim.SGD
class ImageGLVQ(GLVQ):
class ImageGLVQ(GLVQ, PrototypeImageModel):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.components.data.clamp_(0.0, 1.0)
pass
class SiameseGLVQ(GLVQ):
@@ -235,6 +243,7 @@ class GMLVQ(GLVQ):
def forward(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
latent_x = self.omega_layer(x)
latent_protos = self.omega_layer(protos)
dis = squared_euclidean_distance(latent_x, latent_protos)
@@ -256,6 +265,16 @@ class GMLVQ(GLVQ):
return y_pred.numpy()
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
    """GMLVQ variant intended for image data.

    Combines GMLVQ with PrototypeImageModel, so the prototypes are clamped
    to the range [0, 1] after updates. The class adds no members of its own.
    """
class LVQMLN(GLVQ):
"""Learning Vector Quantization Multi-Layer Network.

View File

@@ -3,6 +3,7 @@ import os
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
@@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback):
border=1,
resolution=50,
show_protos=True,
show=True,
tensorboard=False,
show_last_only=False,
pause_time=0.1,
@@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback):
self.border = border
self.resolution = resolution
self.show_protos = show_protos
self.show = show
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.pause_time = pause_time
@@ -352,10 +355,11 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module):
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if not self.block:
plt.pause(self.pause_time)
else:
plt.show(block=True)
if self.show:
if not self.block:
plt.pause(self.pause_time)
else:
plt.show(block=True)
def on_train_end(self, trainer, pl_module):
plt.show()
@@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract):
)
self.log_and_display(trainer, pl_module)
class VisImgComp(Vis2DAbstract):
    """Callback that renders a model's components as an image grid.

    The grid is drawn with matplotlib when ``self.show`` is set, and it is
    written to TensorBoard by ``add_to_tensorboard``.

    Args:
        *args: Forwarded unchanged to ``Vis2DAbstract.__init__``.
        random_data: Number of training samples to draw at random (without
            replacement) and log as an additional "Data" image grid. A falsy
            value (the default ``0``) disables this.
        dataformats: Passed through as ``dataformats`` to ``add_image``.
        nrow: Passed through as ``nrow`` to ``torchvision.utils.make_grid``.
        **kwargs: Forwarded unchanged to ``Vis2DAbstract.__init__``.

    NOTE(review): ``self.show``, ``self.cmap``, ``self.x_train``,
    ``self.precheck`` and ``self.log_and_display`` are not defined in this
    class; presumably they come from ``Vis2DAbstract`` -- confirm.
    """

    def __init__(self,
                 *args,
                 random_data=0,
                 dataformats="CHW",
                 nrow=2,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.random_data = random_data
        self.dataformats = dataformats
        self.nrow = nrow

    def on_epoch_end(self, trainer, pl_module):
        """Draw the component grid (if showing) and hand off to display/logging.

        Returns ``True`` early when ``self.precheck(trainer)`` is falsy;
        otherwise returns ``None``.
        """
        if not self.precheck(trainer):
            return True

        if self.show:
            image_grid = torchvision.utils.make_grid(pl_module.components,
                                                     nrow=self.nrow)
            # Reorder axes (1, 2, 0) before plotting -- presumably
            # channels-first to channels-last for imshow; confirm.
            plt.imshow(image_grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)

        self.log_and_display(trainer, pl_module)

    def add_to_tensorboard(self, trainer, pl_module):
        """Log the component grid, and optionally a random data grid."""
        writer = pl_module.logger.experiment

        def log_grid(tag, images):
            # One grid per call, stamped with the current epoch.
            writer.add_image(
                tag=tag,
                img_tensor=torchvision.utils.make_grid(images,
                                                       nrow=self.nrow),
                global_step=trainer.current_epoch,
                dataformats=self.dataformats,
            )

        log_grid("Components", pl_module.components)

        if self.random_data:
            # Sample distinct indices into the stored training data.
            chosen = np.random.choice(len(self.x_train),
                                      size=self.random_data,
                                      replace=False)
            log_grid("Data", self.x_train[chosen])