Stop passing component initializers as hparams

Passing the component initializer as an hparam slows down the script
considerably. The API has now been changed so that the initializer is passed
to the models as a kwarg instead.

The example scripts have also been updated to reflect the new changes.

In addition, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it
have been added.
This commit is contained in:
Jensun Ravichandran 2021-05-12 16:36:22 +02:00
parent 1498c4bde5
commit ca39aa00d5
11 changed files with 172 additions and 21 deletions

View File

@ -51,7 +51,7 @@ To assist in the development process, you may also find it useful to install
## Available models ## Available models
- K-Nearest Neighbors (KNN) - k-Nearest Neighbors (KNN)
- Learning Vector Quantization 1 (LVQ1) - Learning Vector Quantization 1 (LVQ1)
- Generalized Learning Vector Quantization (GLVQ) - Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ) - Generalized Relevance Learning Vector Quantization (GRLVQ)
@ -68,6 +68,7 @@ To assist in the development process, you may also find it useful to install
## Planned models ## Planned models
- Median-LVQ
- Local-Matrix GMLVQ - Local-Matrix GMLVQ
- Generalized Tangent Learning Vector Quantization (GTLVQ) - Generalized Tangent Learning Vector Quantization (GTLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ) - Robust Soft Learning Vector Quantization (RSLVQ)

View File

@ -21,12 +21,13 @@ if __name__ == "__main__":
prototypes_per_class = 2 prototypes_per_class = 2
hparams = dict( hparams = dict(
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
prototype_initializer=pt.components.SMI(train_ds),
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam) model = pt.models.GLVQ(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SMI(train_ds))
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) vis = pt.models.VisGLVQ2D(data=(x_train, y_train))

View File

@ -29,14 +29,15 @@ if __name__ == "__main__":
prototypes_per_class = 20 prototypes_per_class = 20
hparams = dict( hparams = dict(
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
transfer_function="sigmoid_beta", transfer_function="sigmoid_beta",
transfer_beta=10.0, transfer_beta=10.0,
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.GLVQ(hparams) model = pt.models.GLVQ(hparams,
prototype_initializer=pt.components.SSI(train_ds,
noise=1e-1))
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True) vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)

View File

@ -21,12 +21,12 @@ if __name__ == "__main__":
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
input_dim=x_train.shape[1], input_dim=x_train.shape[1],
latent_dim=x_train.shape[1], latent_dim=x_train.shape[1],
prototype_initializer=pt.components.SMI(train_ds),
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.GMLVQ(hparams) model = pt.models.GMLVQ(hparams,
prototype_initializer=pt.components.SMI(train_ds))
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100) trainer = pl.Trainer(max_epochs=100)

68
examples/gmlvq_mnist.py Normal file
View File

@ -0,0 +1,68 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
if __name__ == "__main__":
    # Dataset: both MNIST splits use the same download flag and the same
    # tensor conversion, so the shared arguments are collected once.
    mnist_kwargs = dict(
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]),
    )
    train_ds = MNIST("~/datasets", train=True, **mnist_kwargs)
    test_ds = MNIST("~/datasets", train=False, **mnist_kwargs)

    # Dataloaders: identical settings for both splits. Only the training
    # loader is handed to the trainer below.
    loader_kwargs = dict(num_workers=0, batch_size=256)
    train_loader = torch.utils.data.DataLoader(train_ds, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_ds, **loader_kwargs)

    # Hyperparameters: the input and latent dimensions both equal the number
    # of pixels of a flattened 28x28 image.
    nclasses = 10
    prototypes_per_class = 2
    hparams = {
        "input_dim": 28 * 28,
        "latent_dim": 28 * 28,
        "distribution": (nclasses, prototypes_per_class),
        "lr": 0.01,
    }

    # Model: the prototype initializer is passed as a keyword argument and is
    # deliberately kept out of `hparams`.
    initializer = pt.components.SMI(train_ds)
    model = pt.models.ImageGMLVQ(
        hparams,
        optimizer=torch.optim.Adam,
        prototype_initializer=initializer,
    )

    # Callback: log the components as an image grid to tensorboard without
    # opening a matplotlib window.
    vis = pt.models.VisImgComp(
        data=train_ds,
        nrow=5,
        show=False,
        tensorboard=True,
    )

    # Trainer setup. For quick debugging runs, `overfit_batches=1` or
    # `fast_dev_run=3` can be added to the arguments below.
    trainer = pl.Trainer(max_epochs=50, callbacks=[vis])

    # Training loop
    trainer.fit(model, train_loader)

View File

@ -23,12 +23,12 @@ if __name__ == "__main__":
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
input_dim=100, input_dim=100,
latent_dim=2, latent_dim=2,
prototype_initializer=pt.components.SMI(train_ds),
lr=0.001, lr=0.001,
) )
# Initialize the model # Initialize the model
model = pt.models.GMLVQ(hparams) model = pt.models.GMLVQ(hparams,
prototype_initializer=pt.components.SMI(train_ds))
# Callbacks # Callbacks
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)

View File

@ -37,7 +37,6 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
prototype_initializer=pt.components.SMI(train_ds),
proto_lr=0.01, proto_lr=0.01,
bb_lr=0.01, bb_lr=0.01,
) )
@ -45,6 +44,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.SiameseGLVQ( model = pt.models.SiameseGLVQ(
hparams, hparams,
prototype_initializer=pt.components.SMI(train_ds),
backbone_module=Backbone, backbone_module=Backbone,
) )

View File

@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC from .cbc import CBC
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ, from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
SiameseGLVQ) ImageGMLVQ, SiameseGLVQ)
from .knn import KNN from .knn import KNN
from .neural_gas import NeuralGas from .neural_gas import NeuralGas
from .vis import * from .vis import *

View File

@ -8,6 +8,11 @@ class AbstractPrototypeModel(pl.LightningModule):
def prototypes(self): def prototypes(self):
return self.proto_layer.components.detach().cpu() return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self): def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr) optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer, scheduler = ExponentialLR(optimizer,
@ -19,3 +24,8 @@ class AbstractPrototypeModel(pl.LightningModule):
"interval": "step", "interval": "step",
} # called after each training step } # called after each training step
return [optimizer], [sch] return [optimizer], [sch]
class PrototypeImageModel(pl.LightningModule):
    """Lightning module that keeps its prototypes inside the range [0, 1].

    Expects the instance to provide ``self.proto_layer.components``.
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        """Clamp the prototype components in place to [0, 1].

        The hook arguments are accepted but not used.
        """
        prototypes = self.proto_layer.components
        prototypes.data.clamp_(min=0.0, max=1.0)

View File

@ -5,9 +5,18 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance, from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance) squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules.mappings import OmegaMapping
from .abstract import AbstractPrototypeModel from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel): class GLVQ(AbstractPrototypeModel):
@ -18,6 +27,7 @@ class GLVQ(AbstractPrototypeModel):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam) self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
prototype_initializer = kwargs.get("prototype_initializer", None)
# Default Values # Default Values
self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("distance", euclidean_distance)
@ -26,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=self.hparams.prototype_initializer) initializer=prototype_initializer)
self.transfer_function = get_activation(self.hparams.transfer_function) self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy() self.train_acc = torchmetrics.Accuracy()
@ -44,7 +54,6 @@ class GLVQ(AbstractPrototypeModel):
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
x, y = train_batch x, y = train_batch
x = x.view(x.size(0), -1) # flatten
dis = self(x) dis = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
mu = self.loss(dis, y, prototype_labels=plabels) mu = self.loss(dis, y, prototype_labels=plabels)
@ -95,15 +104,14 @@ class LVQ21(GLVQ):
self.optimizer = torch.optim.SGD self.optimizer = torch.optim.SGD
class ImageGLVQ(GLVQ): class ImageGLVQ(GLVQ, PrototypeImageModel):
"""GLVQ for training on image data. """GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping GLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates. after updates.
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): pass
self.proto_layer.components.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ): class SiameseGLVQ(GLVQ):
@ -235,6 +243,7 @@ class GMLVQ(GLVQ):
def forward(self, x): def forward(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
latent_x = self.omega_layer(x) latent_x = self.omega_layer(x)
latent_protos = self.omega_layer(protos) latent_protos = self.omega_layer(protos)
dis = squared_euclidean_distance(latent_x, latent_protos) dis = squared_euclidean_distance(latent_x, latent_protos)
@ -256,6 +265,16 @@ class GMLVQ(GLVQ):
return y_pred.numpy() return y_pred.numpy()
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
    """GMLVQ variant intended for image data.

    Combines GMLVQ with PrototypeImageModel so that the prototypes are
    clamped to the range [0, 1] after updates. No behavior of its own is
    defined here; everything is inherited from the two base classes.
    """
class LVQMLN(GLVQ): class LVQMLN(GLVQ):
"""Learning Vector Quantization Multi-Layer Network. """Learning Vector Quantization Multi-Layer Network.

View File

@ -3,6 +3,7 @@ import os
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera from prototorch.utils.celluloid import Camera
@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback):
border=1, border=1,
resolution=50, resolution=50,
show_protos=True, show_protos=True,
show=True,
tensorboard=False, tensorboard=False,
show_last_only=False, show_last_only=False,
pause_time=0.1, pause_time=0.1,
@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback):
self.border = border self.border = border
self.resolution = resolution self.resolution = resolution
self.show_protos = show_protos self.show_protos = show_protos
self.show = show
self.tensorboard = tensorboard self.tensorboard = tensorboard
self.show_last_only = show_last_only self.show_last_only = show_last_only
self.pause_time = pause_time self.pause_time = pause_time
@ -352,6 +355,7 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module): def log_and_display(self, trainer, pl_module):
if self.tensorboard: if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module) self.add_to_tensorboard(trainer, pl_module)
if self.show:
if not self.block: if not self.block:
plt.pause(self.pause_time) plt.pause(self.pause_time)
else: else:
@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract):
) )
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisImgComp(Vis2DAbstract):
    """Visualization callback that shows model components as an image grid.

    The grid can be drawn with matplotlib (when ``self.show`` is set) and
    written to tensorboard through ``add_to_tensorboard``.
    """
    def __init__(self,
                 *args,
                 random_data=0,
                 dataformats="CHW",
                 nrow=2,
                 **kwargs):
        """Store the grid options and pass everything else to the base class.

        Args:
            random_data: Number of randomly drawn samples from
                ``self.x_train`` to additionally log under the "Data" tag.
                A falsy value (the default ``0``) skips that step.
            dataformats: Passed as ``dataformats`` to ``add_image``.
            nrow: Passed as ``nrow`` to ``torchvision.utils.make_grid``.
        """
        super().__init__(*args, **kwargs)
        self.nrow = nrow
        self.dataformats = dataformats
        self.random_data = random_data

    def on_epoch_end(self, trainer, pl_module):
        """Draw the component grid (if showing is enabled) and log/display.

        Returns ``True`` early when ``self.precheck(trainer)`` fails.
        """
        if not self.precheck(trainer):
            return True

        if self.show:
            # make_grid yields a channel-first image; imshow needs the channel
            # axis last, hence the permute.
            # NOTE(review): assumes the components are image-shaped tensors
            # -- confirm against the model.
            image = torchvision.utils.make_grid(pl_module.components,
                                                nrow=self.nrow)
            plt.imshow(image.permute((1, 2, 0)).cpu(), cmap=self.cmap)

        # Called regardless of `self.show` so that tensorboard logging still
        # happens when no window is displayed.
        self.log_and_display(trainer, pl_module)

    def add_to_tensorboard(self, trainer, pl_module):
        """Write the component grid, and optionally random data, as images."""
        writer = pl_module.logger.experiment

        component_grid = torchvision.utils.make_grid(pl_module.components,
                                                     nrow=self.nrow)
        writer.add_image(
            tag="Components",
            img_tensor=component_grid,
            global_step=trainer.current_epoch,
            dataformats=self.dataformats,
        )

        if not self.random_data:
            return

        # Sample distinct indices so that no data point is shown twice.
        chosen = np.random.choice(len(self.x_train),
                                  size=self.random_data,
                                  replace=False)
        data_grid = torchvision.utils.make_grid(self.x_train[chosen],
                                                nrow=self.nrow)
        writer.add_image(
            tag="Data",
            img_tensor=data_grid,
            global_step=trainer.current_epoch,
            dataformats=self.dataformats,
        )