Update example script

This commit is contained in:
Jensun Ravichandran 2021-04-29 19:25:08 +02:00
parent ccaa52c408
commit e44516fc49

View File

@ -32,9 +32,9 @@ if __name__ == "__main__":
hparams = dict(
nclasses=2,
prototypes_per_class=20,
# prototype_initializer=cinit.SSI(torch.Tensor(x_train),
prototype_initializer=cinit.SMI(torch.Tensor(x_train),
torch.Tensor(y_train)),
prototype_initializer=cinit.SSI(torch.Tensor(x_train),
torch.Tensor(y_train),
noise=1e-7),
lr=0.01,
)
@ -42,8 +42,7 @@ if __name__ == "__main__":
model = GLVQ(hparams)
# Callbacks
vis = VisGLVQ2D(x_train, y_train)
# vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
snan = StopOnNaN(model.proto_layer.components)
# Setup trainer