Stop passing component initializers as hparams
Pass the component initializer as an hparam slows down the script very much. The API has now been changed to pass it as a kwarg to the models instead. The example scripts have also been updated to reflect the new changes. Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also been added.
This commit is contained in:
parent
1498c4bde5
commit
ca39aa00d5
@ -51,7 +51,7 @@ To assist in the development process, you may also find it useful to install
|
|||||||
|
|
||||||
## Available models
|
## Available models
|
||||||
|
|
||||||
- K-Nearest Neighbors (KNN)
|
- k-Nearest Neighbors (KNN)
|
||||||
- Learning Vector Quantization 1 (LVQ1)
|
- Learning Vector Quantization 1 (LVQ1)
|
||||||
- Generalized Learning Vector Quantization (GLVQ)
|
- Generalized Learning Vector Quantization (GLVQ)
|
||||||
- Generalized Relevance Learning Vector Quantization (GRLVQ)
|
- Generalized Relevance Learning Vector Quantization (GRLVQ)
|
||||||
@ -68,6 +68,7 @@ To assist in the development process, you may also find it useful to install
|
|||||||
|
|
||||||
## Planned models
|
## Planned models
|
||||||
|
|
||||||
|
- Median-LVQ
|
||||||
- Local-Matrix GMLVQ
|
- Local-Matrix GMLVQ
|
||||||
- Generalized Tangent Learning Vector Quantization (GTLVQ)
|
- Generalized Tangent Learning Vector Quantization (GTLVQ)
|
||||||
- Robust Soft Learning Vector Quantization (RSLVQ)
|
- Robust Soft Learning Vector Quantization (RSLVQ)
|
||||||
|
@ -21,12 +21,13 @@ if __name__ == "__main__":
|
|||||||
prototypes_per_class = 2
|
prototypes_per_class = 2
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=(nclasses, prototypes_per_class),
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
|
model = pt.models.GLVQ(hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds))
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||||
|
@ -29,14 +29,15 @@ if __name__ == "__main__":
|
|||||||
prototypes_per_class = 20
|
prototypes_per_class = 20
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=(nclasses, prototypes_per_class),
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
|
|
||||||
transfer_function="sigmoid_beta",
|
transfer_function="sigmoid_beta",
|
||||||
transfer_beta=10.0,
|
transfer_beta=10.0,
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(hparams)
|
model = pt.models.GLVQ(hparams,
|
||||||
|
prototype_initializer=pt.components.SSI(train_ds,
|
||||||
|
noise=1e-1))
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||||
|
@ -21,12 +21,12 @@ if __name__ == "__main__":
|
|||||||
distribution=(nclasses, prototypes_per_class),
|
distribution=(nclasses, prototypes_per_class),
|
||||||
input_dim=x_train.shape[1],
|
input_dim=x_train.shape[1],
|
||||||
latent_dim=x_train.shape[1],
|
latent_dim=x_train.shape[1],
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GMLVQ(hparams)
|
model = pt.models.GMLVQ(hparams,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds))
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=100)
|
trainer = pl.Trainer(max_epochs=100)
|
||||||
|
68
examples/gmlvq_mnist.py
Normal file
68
examples/gmlvq_mnist.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""GMLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
|
train_ds = MNIST(
|
||||||
|
"~/datasets",
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
transform=transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
test_ds = MNIST(
|
||||||
|
"~/datasets",
|
||||||
|
train=False,
|
||||||
|
download=True,
|
||||||
|
transform=transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
nclasses = 10
|
||||||
|
prototypes_per_class = 2
|
||||||
|
hparams = dict(
|
||||||
|
input_dim=28 * 28,
|
||||||
|
latent_dim=28 * 28,
|
||||||
|
distribution=(nclasses, prototypes_per_class),
|
||||||
|
lr=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.ImageGMLVQ(
|
||||||
|
hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisImgComp(data=train_ds,
|
||||||
|
nrow=5,
|
||||||
|
show=False,
|
||||||
|
tensorboard=True)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=50,
|
||||||
|
callbacks=[vis],
|
||||||
|
# overfit_batches=1,
|
||||||
|
# fast_dev_run=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@ -23,12 +23,12 @@ if __name__ == "__main__":
|
|||||||
distribution=(nclasses, prototypes_per_class),
|
distribution=(nclasses, prototypes_per_class),
|
||||||
input_dim=100,
|
input_dim=100,
|
||||||
latent_dim=2,
|
latent_dim=2,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
|
||||||
lr=0.001,
|
lr=0.001,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GMLVQ(hparams)
|
model = pt.models.GMLVQ(hparams,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds))
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
||||||
|
@ -37,7 +37,6 @@ if __name__ == "__main__":
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[1, 2, 3],
|
distribution=[1, 2, 3],
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
|
||||||
proto_lr=0.01,
|
proto_lr=0.01,
|
||||||
bb_lr=0.01,
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
@ -45,6 +44,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.SiameseGLVQ(
|
model = pt.models.SiameseGLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
backbone_module=Backbone,
|
backbone_module=Backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version
|
|||||||
|
|
||||||
from .cbc import CBC
|
from .cbc import CBC
|
||||||
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
|
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
|
||||||
SiameseGLVQ)
|
ImageGMLVQ, SiameseGLVQ)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .neural_gas import NeuralGas
|
from .neural_gas import NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
@ -8,6 +8,11 @@ class AbstractPrototypeModel(pl.LightningModule):
|
|||||||
def prototypes(self):
|
def prototypes(self):
|
||||||
return self.proto_layer.components.detach().cpu()
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""Only an alias for the prototypes."""
|
||||||
|
return self.prototypes
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
scheduler = ExponentialLR(optimizer,
|
scheduler = ExponentialLR(optimizer,
|
||||||
@ -19,3 +24,8 @@ class AbstractPrototypeModel(pl.LightningModule):
|
|||||||
"interval": "step",
|
"interval": "step",
|
||||||
} # called after each training step
|
} # called after each training step
|
||||||
return [optimizer], [sch]
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
|
class PrototypeImageModel(pl.LightningModule):
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||||
|
@ -5,9 +5,18 @@ from prototorch.functions.activations import get_activation
|
|||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
||||||
squared_euclidean_distance)
|
squared_euclidean_distance)
|
||||||
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
from prototorch.modules.mappings import OmegaMapping
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||||
|
|
||||||
|
|
||||||
|
class GLVQ(AbstractPrototypeModel):
|
||||||
|
"""Generalized Learning Vector Quantization."""
|
||||||
|
|
||||||
|
|
||||||
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
@ -18,6 +27,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
@ -26,7 +36,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=self.hparams.distribution,
|
||||||
initializer=self.hparams.prototype_initializer)
|
initializer=prototype_initializer)
|
||||||
|
|
||||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
@ -44,7 +54,6 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
x = x.view(x.size(0), -1) # flatten
|
|
||||||
dis = self(x)
|
dis = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||||
@ -95,15 +104,14 @@ class LVQ21(GLVQ):
|
|||||||
self.optimizer = torch.optim.SGD
|
self.optimizer = torch.optim.SGD
|
||||||
|
|
||||||
|
|
||||||
class ImageGLVQ(GLVQ):
|
class ImageGLVQ(GLVQ, PrototypeImageModel):
|
||||||
"""GLVQ for training on image data.
|
"""GLVQ for training on image data.
|
||||||
|
|
||||||
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||||
after updates.
|
after updates.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
pass
|
||||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
@ -235,6 +243,7 @@ class GMLVQ(GLVQ):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
|
x, protos = get_flat(x, protos)
|
||||||
latent_x = self.omega_layer(x)
|
latent_x = self.omega_layer(x)
|
||||||
latent_protos = self.omega_layer(protos)
|
latent_protos = self.omega_layer(protos)
|
||||||
dis = squared_euclidean_distance(latent_x, latent_protos)
|
dis = squared_euclidean_distance(latent_x, latent_protos)
|
||||||
@ -256,6 +265,16 @@ class GMLVQ(GLVQ):
|
|||||||
return y_pred.numpy()
|
return y_pred.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
|
||||||
|
"""GMLVQ for training on image data.
|
||||||
|
|
||||||
|
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||||
|
after updates.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LVQMLN(GLVQ):
|
class LVQMLN(GLVQ):
|
||||||
"""Learning Vector Quantization Multi-Layer Network.
|
"""Learning Vector Quantization Multi-Layer Network.
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from matplotlib.offsetbox import AnchoredText
|
from matplotlib.offsetbox import AnchoredText
|
||||||
from prototorch.utils.celluloid import Camera
|
from prototorch.utils.celluloid import Camera
|
||||||
@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
border=1,
|
border=1,
|
||||||
resolution=50,
|
resolution=50,
|
||||||
show_protos=True,
|
show_protos=True,
|
||||||
|
show=True,
|
||||||
tensorboard=False,
|
tensorboard=False,
|
||||||
show_last_only=False,
|
show_last_only=False,
|
||||||
pause_time=0.1,
|
pause_time=0.1,
|
||||||
@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.border = border
|
self.border = border
|
||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
self.show_protos = show_protos
|
self.show_protos = show_protos
|
||||||
|
self.show = show
|
||||||
self.tensorboard = tensorboard
|
self.tensorboard = tensorboard
|
||||||
self.show_last_only = show_last_only
|
self.show_last_only = show_last_only
|
||||||
self.pause_time = pause_time
|
self.pause_time = pause_time
|
||||||
@ -352,10 +355,11 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
def log_and_display(self, trainer, pl_module):
|
def log_and_display(self, trainer, pl_module):
|
||||||
if self.tensorboard:
|
if self.tensorboard:
|
||||||
self.add_to_tensorboard(trainer, pl_module)
|
self.add_to_tensorboard(trainer, pl_module)
|
||||||
if not self.block:
|
if self.show:
|
||||||
plt.pause(self.pause_time)
|
if not self.block:
|
||||||
else:
|
plt.pause(self.pause_time)
|
||||||
plt.show(block=True)
|
else:
|
||||||
|
plt.show(block=True)
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module):
|
def on_train_end(self, trainer, pl_module):
|
||||||
plt.show()
|
plt.show()
|
||||||
@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
|
class VisImgComp(Vis2DAbstract):
|
||||||
|
def __init__(self,
|
||||||
|
*args,
|
||||||
|
random_data=0,
|
||||||
|
dataformats="CHW",
|
||||||
|
nrow=2,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.random_data = random_data
|
||||||
|
self.dataformats = dataformats
|
||||||
|
self.nrow = nrow
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.show:
|
||||||
|
components = pl_module.components
|
||||||
|
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||||
|
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
def add_to_tensorboard(self, trainer, pl_module):
|
||||||
|
tb = pl_module.logger.experiment
|
||||||
|
|
||||||
|
components = pl_module.components
|
||||||
|
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||||
|
tb.add_image(
|
||||||
|
tag="Components",
|
||||||
|
img_tensor=grid,
|
||||||
|
global_step=trainer.current_epoch,
|
||||||
|
dataformats=self.dataformats,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.random_data:
|
||||||
|
ind = np.random.choice(len(self.x_train),
|
||||||
|
size=self.random_data,
|
||||||
|
replace=False)
|
||||||
|
data_img = self.x_train[ind]
|
||||||
|
grid = torchvision.utils.make_grid(data_img, nrow=self.nrow)
|
||||||
|
tb.add_image(tag="Data",
|
||||||
|
img_tensor=grid,
|
||||||
|
global_step=trainer.current_epoch,
|
||||||
|
dataformats=self.dataformats)
|
||||||
|
Loading…
Reference in New Issue
Block a user