From b031382072681718849a29fb0e32909bf5c13022 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 7 Jun 2021 18:35:08 +0200 Subject: [PATCH] Update NG --- examples/gng_iris.py | 6 ++---- examples/ng_iris.py | 6 +++++- prototorch/models/unsupervised.py | 1 - 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/gng_iris.py b/examples/gng_iris.py index 97f0b22..dd25b16 100644 --- a/examples/gng_iris.py +++ b/examples/gng_iris.py @@ -17,11 +17,12 @@ if __name__ == "__main__": # Prepare the data train_ds = pt.datasets.Iris(dims=[0, 2]) - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) # Hyperparameters hparams = dict( num_prototypes=5, + input_dim=2, lr=0.1, ) @@ -50,6 +51,3 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) - - # Model summary - print(model) diff --git a/examples/ng_iris.py b/examples/ng_iris.py index 689690e..1ecee90 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -28,7 +28,11 @@ if __name__ == "__main__": train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) # Hyperparameters - hparams = dict(num_prototypes=30, lr=0.03) + hparams = dict( + num_prototypes=30, + input_dim=2, + lr=0.03, + ) # Initialize the model model = pt.models.NeuralGas( diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index 73aec78..2618c96 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -28,7 +28,6 @@ class NeuralGas(UnsupervisedPrototypeModel): self.save_hyperparameters(hparams) # Default hparams - self.hparams.setdefault("input_dim", 2) self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("lm", 1)