From e87663d10cb4fd0cb84dd7b2e78b516fa8f08d87 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 13:07:30 +0200 Subject: [PATCH] Make siamese example script reproducible --- examples/siamese_glvq_iris.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 8a4530f..d117f4f 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -2,18 +2,16 @@ import pytorch_lightning as pl import torch -from prototorch.components import ( - StratifiedMeanInitializer -) +from prototorch.components import initializers as cinit from prototorch.datasets.abstract import NumpyDataset +from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D +from prototorch.models.glvq import SiameseGLVQ from sklearn.datasets import load_iris from torch.utils.data import DataLoader -from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D -from prototorch.models.glvq import SiameseGLVQ - class Backbone(torch.nn.Module): + """Two fully connected layers with ReLU activation.""" def __init__(self, input_size=4, hidden_size=10, latent_size=2): super().__init__() self.input_size = input_size @@ -24,7 +22,9 @@ class Backbone(torch.nn.Module): self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.dense2(self.relu(self.dense1(x)))) + x = self.relu(self.dense1(x)) + out = self.relu(self.dense2(x)) + return out if __name__ == "__main__": @@ -32,16 +32,20 @@ if __name__ == "__main__": x_train, y_train = load_iris(return_X_y=True) train_ds = NumpyDataset(x_train, y_train) + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + # Dataloaders train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) # Hyperparameters hparams = dict( nclasses=3, - prototypes_per_class=1, - prototype_initializer=StratifiedMeanInitializer( - torch.Tensor(x_train), torch.Tensor(y_train)), - lr=0.01, + prototypes_per_class=2, + prototype_initializer=cinit.SMI(torch.Tensor(x_train), + torch.Tensor(y_train)), + proto_lr=0.001, + bb_lr=0.001, ) # Initialize the model @@ -54,7 +58,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = VisSiameseGLVQ2D(x_train, y_train) + vis = VisSiameseGLVQ2D(x_train, y_train, border=0.1) # Setup trainer trainer = pl.Trainer(max_epochs=100, callbacks=[vis])