diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index 92f0791..ae5c5ae 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -11,6 +11,9 @@ if __name__ == "__main__": x_train = x_train[:, [0, 2]] train_ds = pt.datasets.NumpyDataset(x_train, y_train) + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + # Dataloaders train_loader = torch.utils.data.DataLoader(train_ds, num_workers=0, @@ -20,8 +23,8 @@ if __name__ == "__main__": hparams = dict( input_dim=x_train.shape[1], nclasses=3, - num_components=9, - component_initializer=pt.components.SMI(train_ds), + num_components=5, + component_initializer=pt.components.SSI(train_ds, noise=0.01), lr=0.01, ) @@ -34,7 +37,7 @@ if __name__ == "__main__": # Setup trainer trainer = pl.Trainer( - max_epochs=50, + max_epochs=200, callbacks=[ dvis, ],