[BUG] Training unstable in examples/gng_iris.py

This commit is contained in:
Jensun Ravichandran 2021-06-02 00:21:42 +02:00
parent 98c198d463
commit 9eb6476078

View File

@ -1,3 +1,5 @@
"""Growing Neural Gas example using the Iris dataset."""
import argparse import argparse
import prototorch as pt import prototorch as pt
@ -13,12 +15,15 @@ if __name__ == "__main__":
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Prepare the data # Prepare the data
train_ds = Iris(dims=[0, 2]) train_ds = Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=32) train_loader = DataLoader(train_ds, batch_size=8)
# Hyperparameters # Hyperparameters
hparams = dict(num_prototypes=2, hparams = dict(num_prototypes=5,
lr=0.1, lr=0.1,
prototype_initializer=SelectionInitializer(train_ds.data)) prototype_initializer=SelectionInitializer(train_ds.data))