Update iris example
This commit is contained in:
parent
58efa5a4cf
commit
a8a99f6971
@ -20,8 +20,8 @@ x_train = scaler.transform(x_train)
|
|||||||
|
|
||||||
# Define the GLVQ model
|
# Define the GLVQ model
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self):
|
||||||
"""GLVQ model."""
|
"""GLVQ model for training on 2D Iris data."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proto_layer = Prototypes1D(
|
self.proto_layer = Prototypes1D(
|
||||||
input_dim=2,
|
input_dim=2,
|
||||||
@ -64,7 +64,7 @@ for epoch in range(70):
|
|||||||
# Get the prototypes form the model
|
# Get the prototypes form the model
|
||||||
protos = model.proto_layer.prototypes.data.numpy()
|
protos = model.proto_layer.prototypes.data.numpy()
|
||||||
if np.isnan(np.sum(protos)):
|
if np.isnan(np.sum(protos)):
|
||||||
print(f'Stopping because of `nan` in prototypes.')
|
print('Stopping training because of `nan` in prototypes.')
|
||||||
break
|
break
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
# Visualize the data and the prototypes
|
||||||
|
Loading…
Reference in New Issue
Block a user