Update iris example

This commit is contained in:
Jensun Ravichandran 2020-09-24 11:54:18 +02:00
parent 58efa5a4cf
commit a8a99f6971

View File

@ -20,8 +20,8 @@ x_train = scaler.transform(x_train)
# Define the GLVQ model # Define the GLVQ model
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self):
"""GLVQ model.""" """GLVQ model for training on 2D Iris data."""
super().__init__() super().__init__()
self.proto_layer = Prototypes1D( self.proto_layer = Prototypes1D(
input_dim=2, input_dim=2,
@ -64,7 +64,7 @@ for epoch in range(70):
# Get the prototypes form the model # Get the prototypes form the model
protos = model.proto_layer.prototypes.data.numpy() protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)): if np.isnan(np.sum(protos)):
print(f'Stopping because of `nan` in prototypes.') print('Stopping training because of `nan` in prototypes.')
break break
# Visualize the data and the prototypes # Visualize the data and the prototypes