Update example script
This commit is contained in:
parent
ccaa52c408
commit
e44516fc49
@ -32,9 +32,9 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=20,
|
prototypes_per_class=20,
|
||||||
# prototype_initializer=cinit.SSI(torch.Tensor(x_train),
|
prototype_initializer=cinit.SSI(torch.Tensor(x_train),
|
||||||
prototype_initializer=cinit.SMI(torch.Tensor(x_train),
|
torch.Tensor(y_train),
|
||||||
torch.Tensor(y_train)),
|
noise=1e-7),
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -42,8 +42,7 @@ if __name__ == "__main__":
|
|||||||
model = GLVQ(hparams)
|
model = GLVQ(hparams)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(x_train, y_train)
|
vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
|
||||||
# vis = VisGLVQ2D(x_train, y_train, show_last_only=True, block=True)
|
|
||||||
snan = StopOnNaN(model.proto_layer.components)
|
snan = StopOnNaN(model.proto_layer.components)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
|
Loading…
Reference in New Issue
Block a user