Stop passing component initializers as hparams
Pass the component initializer as an hparam slows down the script very much. The API has now been changed to pass it as a kwarg to the models instead. The example scripts have also been updated to reflect the new changes. Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also been added.
This commit is contained in:
@@ -21,12 +21,13 @@ if __name__ == "__main__":
|
||||
prototypes_per_class = 2
|
||||
hparams = dict(
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
|
||||
model = pt.models.GLVQ(hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
|
@@ -29,14 +29,15 @@ if __name__ == "__main__":
|
||||
prototypes_per_class = 20
|
||||
hparams = dict(
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
|
||||
transfer_function="sigmoid_beta",
|
||||
transfer_beta=10.0,
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(hparams)
|
||||
model = pt.models.GLVQ(hparams,
|
||||
prototype_initializer=pt.components.SSI(train_ds,
|
||||
noise=1e-1))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||
|
@@ -21,12 +21,12 @@ if __name__ == "__main__":
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
input_dim=x_train.shape[1],
|
||||
latent_dim=x_train.shape[1],
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(hparams)
|
||||
model = pt.models.GMLVQ(hparams,
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=100)
|
||||
|
68
examples/gmlvq_mnist.py
Normal file
68
examples/gmlvq_mnist.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
train_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
test_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
nclasses = 10
|
||||
prototypes_per_class = 2
|
||||
hparams = dict(
|
||||
input_dim=28 * 28,
|
||||
latent_dim=28 * 28,
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.ImageGMLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisImgComp(data=train_ds,
|
||||
nrow=5,
|
||||
show=False,
|
||||
tensorboard=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=50,
|
||||
callbacks=[vis],
|
||||
# overfit_batches=1,
|
||||
# fast_dev_run=3,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -23,12 +23,12 @@ if __name__ == "__main__":
|
||||
distribution=(nclasses, prototypes_per_class),
|
||||
input_dim=100,
|
||||
latent_dim=2,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.001,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(hparams)
|
||||
model = pt.models.GMLVQ(hparams,
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
||||
|
@@ -37,7 +37,6 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
@@ -45,6 +44,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGLVQ(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
backbone_module=Backbone,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user