diff --git a/examples/gng_iris.py b/examples/gng_iris.py index 8669f71..1781734 100644 --- a/examples/gng_iris.py +++ b/examples/gng_iris.py @@ -1,3 +1,5 @@ +"""Growing Neural Gas example using the Iris dataset.""" + import argparse import prototorch as pt @@ -13,12 +15,15 @@ if __name__ == "__main__": parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() + # Reproducibility + pl.utilities.seed.seed_everything(seed=42) + # Prepare the data train_ds = Iris(dims=[0, 2]) - train_loader = DataLoader(train_ds, batch_size=32) + train_loader = DataLoader(train_ds, batch_size=8) # Hyperparameters - hparams = dict(num_prototypes=2, + hparams = dict(num_prototypes=5, lr=0.1, prototype_initializer=SelectionInitializer(train_ds.data))