[WIP] Update CBC implementation to use SiameseGLVQ

This commit is contained in:
Jensun Ravichandran
2021-05-20 17:36:00 +02:00
parent 49f9a12b5f
commit 88a34a06ef
3 changed files with 83 additions and 89 deletions

View File

@@ -6,13 +6,10 @@ import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
pl.utilities.seed.seed_everything(seed=3)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
@@ -21,18 +18,19 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
num_components=5,
component_initializer=pt.components.SSI(train_ds, noise=0.01),
lr=0.01,
distribution=[3, 2, 2],
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.CBC(hparams)
model = pt.models.CBC(
hparams,
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
)
# Callbacks
dvis = pt.models.VisCBC2D(data=(x_train, y_train),
dvis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=300,
axis_off=True)