[BUGFIX] examples/lgmlvq_moons.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:34:46 +02:00
parent 6197d7d5d6
commit 1911d4b33e

View File

@ -12,12 +12,12 @@ if __name__ == "__main__":
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args() args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=2) pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256, batch_size=256,
@ -31,8 +31,10 @@ if __name__ == "__main__":
) )
# Initialize the model # Initialize the model
model = pt.models.LGMLVQ(hparams, model = pt.models.LGMLVQ(
prototype_initializer=pt.components.SMI(train_ds)) hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Compute intermediate input and output sizes # Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2) model.example_input_array = torch.zeros(4, 2)