Stop passing component initializers as hparams

Passing the component initializer as an hparam slows down the scripts
considerably. The API has been changed to pass it to the models as a kwarg
instead.
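
Condensed from the diffs below, the change looks like this (assuming the
usual example-script setup: `import prototorch as pt`, a `train_ds` dataset,
and the `nclasses`/`prototypes_per_class` counts):

    # Before: the initializer rode along inside the hparams dict.
    hparams = dict(
        distribution=(nclasses, prototypes_per_class),
        prototype_initializer=pt.components.SMI(train_ds),
        lr=0.01,
    )
    model = pt.models.GLVQ(hparams)

    # After: hparams hold only plain values; the initializer is a kwarg.
    hparams = dict(
        distribution=(nclasses, prototypes_per_class),
        lr=0.01,
    )
    model = pt.models.GLVQ(hparams,
                           prototype_initializer=pt.components.SMI(train_ds))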

The example scripts have been updated to reflect this change.

ImageGMLVQ, along with an example script `gmlvq_mnist.py` that uses it, has
also been added.
Author: Jensun Ravichandran
Date:   2021-05-12 16:36:22 +02:00
Parent: 1498c4bde5
Commit: ca39aa00d5
11 changed files with 172 additions and 21 deletions
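
Why the kwarg route is faster is not spelled out in the message; presumably
anything placed in hparams is stored, logged, and checkpointed along with the
model, while a component initializer holds a reference to the entire training
set. A minimal hypothetical sketch of the pattern (`SketchModel` and its
zero-argument initializer callable are illustrative stand-ins, not the
prototorch API):

    import torch

    class SketchModel(torch.nn.Module):
        """Hypothetical sketch, not the actual prototorch implementation."""

        def __init__(self, hparams: dict, prototype_initializer=None):
            super().__init__()
            # hparams now holds only plain values, so logging and
            # serializing it stays cheap.
            self.hparams = dict(hparams)
            # The initializer (which may reference a whole dataset) is
            # consumed once at construction time and never stored with
            # the hparams.
            initial = (prototype_initializer()
                       if prototype_initializer is not None
                       else torch.zeros(1, 2))
            self.prototypes = torch.nn.Parameter(initial)

    # Usage: the lambda stands in for something like pt.components.SMI(train_ds).
    model = SketchModel({"lr": 0.01},
                        prototype_initializer=lambda: torch.randn(4, 2))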


@@ -21,12 +21,13 @@ if __name__ == "__main__":
     prototypes_per_class = 2
     hparams = dict(
         distribution=(nclasses, prototypes_per_class),
-        prototype_initializer=pt.components.SMI(train_ds),
         lr=0.01,
     )

     # Initialize the model
-    model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
+    model = pt.models.GLVQ(hparams,
+                           optimizer=torch.optim.Adam,
+                           prototype_initializer=pt.components.SMI(train_ds))

     # Callbacks
     vis = pt.models.VisGLVQ2D(data=(x_train, y_train))


@@ -29,14 +29,15 @@ if __name__ == "__main__":
     prototypes_per_class = 20
     hparams = dict(
         distribution=(nclasses, prototypes_per_class),
-        prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
         transfer_function="sigmoid_beta",
         transfer_beta=10.0,
         lr=0.01,
     )

     # Initialize the model
-    model = pt.models.GLVQ(hparams)
+    model = pt.models.GLVQ(hparams,
+                           prototype_initializer=pt.components.SSI(train_ds,
+                                                                   noise=1e-1))

     # Callbacks
     vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)


@@ -21,12 +21,12 @@ if __name__ == "__main__":
         distribution=(nclasses, prototypes_per_class),
         input_dim=x_train.shape[1],
         latent_dim=x_train.shape[1],
-        prototype_initializer=pt.components.SMI(train_ds),
         lr=0.01,
     )

     # Initialize the model
-    model = pt.models.GMLVQ(hparams)
+    model = pt.models.GMLVQ(hparams,
+                            prototype_initializer=pt.components.SMI(train_ds))

     # Setup trainer
     trainer = pl.Trainer(max_epochs=100)

examples/gmlvq_mnist.py (new file, 68 additions)

@@ -0,0 +1,68 @@
+"""GMLVQ example using the MNIST dataset."""
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+from torchvision import transforms
+from torchvision.datasets import MNIST
+
+if __name__ == "__main__":
+    # Dataset
+    train_ds = MNIST(
+        "~/datasets",
+        train=True,
+        download=True,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+        ]),
+    )
+    test_ds = MNIST(
+        "~/datasets",
+        train=False,
+        download=True,
+        transform=transforms.Compose([
+            transforms.ToTensor(),
+        ]),
+    )
+
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(train_ds,
+                                               num_workers=0,
+                                               batch_size=256)
+    test_loader = torch.utils.data.DataLoader(test_ds,
+                                              num_workers=0,
+                                              batch_size=256)
+
+    # Hyperparameters
+    nclasses = 10
+    prototypes_per_class = 2
+    hparams = dict(
+        input_dim=28 * 28,
+        latent_dim=28 * 28,
+        distribution=(nclasses, prototypes_per_class),
+        lr=0.01,
+    )
+
+    # Initialize the model
+    model = pt.models.ImageGMLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototype_initializer=pt.components.SMI(train_ds),
+    )
+
+    # Callbacks
+    vis = pt.models.VisImgComp(data=train_ds,
+                               nrow=5,
+                               show=False,
+                               tensorboard=True)
+
+    # Setup trainer
+    trainer = pl.Trainer(
+        max_epochs=50,
+        callbacks=[vis],
+        # overfit_batches=1,
+        # fast_dev_run=3,
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)


@@ -23,12 +23,12 @@ if __name__ == "__main__":
         distribution=(nclasses, prototypes_per_class),
         input_dim=100,
         latent_dim=2,
-        prototype_initializer=pt.components.SMI(train_ds),
         lr=0.001,
     )

     # Initialize the model
-    model = pt.models.GMLVQ(hparams)
+    model = pt.models.GMLVQ(hparams,
+                            prototype_initializer=pt.components.SMI(train_ds))

     # Callbacks
     vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)


@@ -37,7 +37,6 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
         distribution=[1, 2, 3],
-        prototype_initializer=pt.components.SMI(train_ds),
         proto_lr=0.01,
         bb_lr=0.01,
     )
@@ -45,6 +44,7 @@ if __name__ == "__main__":

     # Initialize the model
     model = pt.models.SiameseGLVQ(
         hparams,
+        prototype_initializer=pt.components.SMI(train_ds),
         backbone_module=Backbone,
     )