diff --git a/examples/lgmlvq_moons.py b/examples/lgmlvq_moons.py index 5c053ad..cbdc9b4 100644 --- a/examples/lgmlvq_moons.py +++ b/examples/lgmlvq_moons.py @@ -12,12 +12,12 @@ if __name__ == "__main__": parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() - # Dataset - train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) - # Reproducibility pl.utilities.seed.seed_everything(seed=2) + # Dataset + train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) + # Dataloaders train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256, @@ -31,8 +31,10 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.LGMLVQ(hparams, - prototype_initializer=pt.components.SMI(train_ds)) + model = pt.models.LGMLVQ( + hparams, + prototypes_initializer=pt.initializers.SMCI(train_ds), + ) # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2)