[BUGFIX] examples/siamese_glvq_iris.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:44:36 +02:00
parent 1c658cdc1b
commit 24ebfdc667

View File

@ -2,11 +2,10 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -52,7 +51,7 @@ if __name__ == "__main__":
# Initialize the model
model = pt.models.SiameseGLVQ(
hparams,
prototype_initializer=pt.components.SMI(train_ds),
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)