Fix: saving GMLVQ and GRLVQ fixed

This commit is contained in:
Alexander Engelsberger 2023-03-09 15:50:13 +01:00
parent 87fa3f0729
commit 46dfb82371
No known key found for this signature in database
4 changed files with 82 additions and 5 deletions

View File

@ -71,3 +71,5 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

74
examples/grlvq_iris.py Normal file
View File

@ -0,0 +1,74 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris([0, 1])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=2,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GRLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=5,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance) distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn) self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
@property @property
def num_prototypes(self): def num_prototypes(self):

View File

@ -209,9 +209,12 @@ class GRLVQ(SiameseGLVQ):
self.register_parameter("_relevances", Parameter(relevances)) self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone # Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances), self.backbone = LambdaLayer(self._apply_relevances,
name="relevance scaling") name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property @property
def relevance_profile(self): def relevance_profile(self):
return self._relevances.detach().cpu() return self._relevances.detach().cpu()
@ -271,8 +274,6 @@ class GMLVQ(GLVQ):
omega = omega_initializer.generate(self.hparams["input_dim"], omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"]) self.hparams["latent_dim"])
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@property @property
def omega_matrix(self): def omega_matrix(self):