diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index 718a129..dee0a84 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -71,3 +71,5 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) + + torch.save(model, "iris.pth") \ No newline at end of file diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py new file mode 100644 index 0000000..2ede559 --- /dev/null +++ b/examples/grlvq_iris.py @@ -0,0 +1,74 @@ +"""GMLVQ example using the Iris dataset.""" + +import argparse +import warnings + +import prototorch as pt +import pytorch_lightning as pl +import torch +from prototorch.models import GRLVQ, VisSiameseGLVQ2D +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +warnings.filterwarnings("ignore", category=PossibleUserWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +if __name__ == "__main__": + + # Reproducibility + seed_everything(seed=4) + + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = pt.datasets.Iris([0, 1]) + + # Dataloaders + train_loader = DataLoader(train_ds, batch_size=64) + + # Hyperparameters + hparams = dict( + input_dim=2, + distribution={ + "num_classes": 3, + "per_class": 2 + }, + proto_lr=0.01, + bb_lr=0.01, + ) + + # Initialize the model + model = GRLVQ( + hparams, + optimizer=torch.optim.Adam, + prototypes_initializer=pt.initializers.SMCI(train_ds), + lr_scheduler=ExponentialLR, + lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), + ) + + # Compute intermediate input and output sizes + model.example_input_array = torch.zeros(4, 2) + + # Callbacks + vis = VisSiameseGLVQ2D(data=train_ds) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[ + vis, + ], + max_epochs=5, + log_every_n_steps=1, + detect_anomaly=True, + ) + + # Training loop + trainer.fit(model, train_loader) + + torch.save(model, "iris.pth") \ No newline at end of file diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index f990145..42b0bb3 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt): super().__init__(hparams, **kwargs) distance_fn = kwargs.get("distance_fn", euclidean_distance) - self.distance_layer = LambdaLayer(distance_fn) + self.distance_layer = LambdaLayer(distance_fn, name="distance_fn") @property def num_prototypes(self): diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index cab60d5..d376c9a 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -209,9 +209,12 @@ class GRLVQ(SiameseGLVQ): self.register_parameter("_relevances", Parameter(relevances)) # Override the backbone - self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances), + self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling") + def _apply_relevances(self, x): + return x @ torch.diag(self._relevances) + @property def relevance_profile(self): return self._relevances.detach().cpu() @@ -271,9 +274,7 @@ class GMLVQ(GLVQ): omega = omega_initializer.generate(self.hparams["input_dim"], self.hparams["latent_dim"]) self.register_parameter("_omega", Parameter(omega)) - self.backbone = LambdaLayer(lambda x: x @ self._omega, - name="omega matrix") - + @property def omega_matrix(self): return self._omega.detach().cpu()