[WIP] Update CBC example

This commit is contained in:
Jensun Ravichandran 2021-06-02 00:45:33 +02:00
parent 88cfd5762e
commit d46fe4a393
2 changed files with 7 additions and 10 deletions

View File

@ -17,18 +17,15 @@ if __name__ == "__main__":
train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility
pl.utilities.seed.seed_everything(seed=3)
pl.utilities.seed.seed_everything(seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
# Hyperparameters
hparams = dict(
distribution=[3, 2, 2],
proto_lr=0.01,
bb_lr=0.01,
distribution=[2, 2, 2],
proto_lr=0.1,
)
# Initialize the model
@ -40,7 +37,7 @@ if __name__ == "__main__":
# Callbacks
vis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=300,
resolution=100,
axis_off=True)
# Setup trainer

View File

@ -16,9 +16,9 @@ def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y, beta=3):
def euclidean_similarity(x, y, variance=1.0):
d = euclidean_distance(x, y)
return torch.exp(-d * beta)
return torch.exp(-(d * d) / (2 * variance))
class CosineSimilarity(torch.nn.Module):