Update example script

This commit is contained in:
Jensun Ravichandran 2021-04-29 19:25:08 +02:00
parent ccaa52c408
commit e44516fc49

View File

@ -32,9 +32,9 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
nclasses=2, nclasses=2,
prototypes_per_class=20, prototypes_per_class=20,
# prototype_initializer=cinit.SSI(torch.Tensor(x_train), prototype_initializer=cinit.SSI(torch.Tensor(x_train),
prototype_initializer=cinit.SMI(torch.Tensor(x_train), torch.Tensor(y_train),
torch.Tensor(y_train)), noise=1e-7),
lr=0.01, lr=0.01,
) )
@ -42,8 +42,7 @@ if __name__ == "__main__":
model = GLVQ(hparams) model = GLVQ(hparams)
# Callbacks # Callbacks
vis = VisGLVQ2D(x_train, y_train) vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
# vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
snan = StopOnNaN(model.proto_layer.components) snan = StopOnNaN(model.proto_layer.components)
# Setup trainer # Setup trainer