From fe36e5fad99689b6c625c10e914bee16c4bfdf84 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 21 Apr 2021 19:16:57 +0200 Subject: [PATCH] Add partial metric/hparam features [BROKEN STATE] --- examples/glvq_iris.py | 53 +++++++++++++++++++++++++++++++++------ prototorch/models/glvq.py | 17 +++++++++++-- setup.py | 2 +- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index f999a09..c8cd3b3 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -1,5 +1,7 @@ """GLVQ example using the Iris dataset.""" +import argparse + import numpy as np import pytorch_lightning as pl import torch @@ -60,6 +62,31 @@ class VisualizationCallback(pl.Callback): if __name__ == "__main__": + # Hyperparameters + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", + type=int, + default=100, + help="Epochs to train.") + parser.add_argument("--lr", + type=float, + default=0.001, + help="Learning rate.") + parser.add_argument("--batch_size", + type=int, + default=256, + help="Batch size.") + parser.add_argument("--gpus", + type=int, + default=0, + help="Number of GPUs to use.") + parser.add_argument("--ppc", + type=int, + default=1, + help="Prototypes-Per-Class.") + args = parser.parse_args() + # https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html + # Dataset x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] @@ -72,10 +99,10 @@ if __name__ == "__main__": model = GLVQ( input_dim=x_train.shape[1], nclasses=3, - prototypes_per_class=3, + prototype_distribution=[2, 7, 5], prototype_initializer="stratified_mean", data=[x_train, y_train], - lr=0.1, + lr=0.01, ) # Model summary @@ -85,12 +112,24 @@ if __name__ == "__main__": vis = VisualizationCallback(x_train, y_train) # Setup trainer - trainer = pl.Trainer(max_epochs=1000, callbacks=[vis]) + trainer = pl.Trainer( + max_epochs=hparams.epochs, + auto_lr_find= + True, # finds learning rate automatically with `trainer.tune(model)` + callbacks=[ + vis, # comment this line out to disable the visualization + ], + ) + trainer.tune(model) # Training loop trainer.fit(model, train_loader) - # Visualization - protos = model.prototypes - plabels = model.prototype_labels - visualize(x_train, y_train, protos, plabels) + # Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate) + ckpt = "glvq_iris.ckpt" + trainer.save_checkpoint(ckpt) + + # Load the checkpoint + new_model = GLVQ.load_from_checkpoint(checkpoint_path=ckpt) + + print(new_model) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index e04cd75..a8f4fcb 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -1,5 +1,6 @@ import pytorch_lightning as pl import torch +import torchmetrics from prototorch.functions.competitions import wtac from prototorch.functions.distances import euclidean_distance from prototorch.functions.initializers import get_initializer @@ -9,10 +10,11 @@ from prototorch.modules.prototypes import Prototypes1D class GLVQ(pl.LightningModule): """Generalized Learning Vector Quantization.""" - def __init__(self, lr=1e-3, **kwargs): + def __init__(self, hparams): super().__init__() - self.lr = lr + self.lr = hparams.lr self.proto_layer = Prototypes1D(**kwargs) + self.train_acc = torchmetrics.Accuracy() @property def prototypes(self): @@ -39,10 +41,21 @@ class GLVQ(pl.LightningModule): mu = glvq_loss(dis, y, prototype_labels=plabels) loss = mu.sum(dim=0) self.log("train_loss", loss) + with torch.no_grad(): + preds = wtac(dis, plabels) + # self.train_acc.update(preds.int(), y.int()) + self.train_acc(preds.int(), y.int()) # FloatTensors are assumed to be class probabilities + self.log("Training Accuracy", self.train_acc, on_step=False, on_epoch=True) return loss + # def training_epoch_end(self, outs): + # # Calling `self.train_acc.compute()` is + # # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)` + # self.log("train_acc_epoch", self.train_acc.compute()) + def predict(self, x): with torch.no_grad(): + # model.eval() # ?! d = self(x) plabels = self.proto_layer.prototype_labels y_pred = wtac(d, plabels) diff --git a/setup.py b/setup.py index 9316ede..9a5d527 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git" with open("README.md", "r") as fh: long_description = fh.read() -INSTALL_REQUIRES = ["prototorch", "pytorch_lightning"] +INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"] EXAMPLES = ["matplotlib", "scikit-learn"] TESTS = ["pytest"] ALL = EXAMPLES + TESTS