[WIP] Update CBC example

This commit is contained in:
Jensun Ravichandran 2021-06-02 00:45:33 +02:00
parent 88cfd5762e
commit d46fe4a393
2 changed files with 7 additions and 10 deletions

View File

@ -17,18 +17,15 @@ if __name__ == "__main__":
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=3) pl.utilities.seed.seed_everything(seed=42)
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
num_workers=0,
batch_size=150)
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[3, 2, 2], distribution=[2, 2, 2],
proto_lr=0.01, proto_lr=0.1,
bb_lr=0.01,
) )
# Initialize the model # Initialize the model
@ -40,7 +37,7 @@ if __name__ == "__main__":
# Callbacks # Callbacks
vis = pt.models.VisCBC2D(data=train_ds, vis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example", title="CBC Iris Example",
resolution=300, resolution=100,
axis_off=True) axis_off=True)
# Setup trainer # Setup trainer

View File

@ -16,9 +16,9 @@ def shift_activation(x):
return (x + 1.0) / 2.0 return (x + 1.0) / 2.0
def euclidean_similarity(x, y, beta=3): def euclidean_similarity(x, y, variance=1.0):
d = euclidean_distance(x, y) d = euclidean_distance(x, y)
return torch.exp(-d * beta) return torch.exp(-(d * d) / (2 * variance))
class CosineSimilarity(torch.nn.Module): class CosineSimilarity(torch.nn.Module):