diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index ed5d824..8d24991 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -20,8 +20,8 @@ x_train = scaler.transform(x_train) # Define the GLVQ model class Model(torch.nn.Module): - def __init__(self, **kwargs): - """GLVQ model.""" + def __init__(self): + """GLVQ model for training on 2D Iris data.""" super().__init__() self.proto_layer = Prototypes1D( input_dim=2, @@ -64,7 +64,7 @@ for epoch in range(70): # Get the prototypes form the model protos = model.proto_layer.prototypes.data.numpy() if np.isnan(np.sum(protos)): - print(f'Stopping because of `nan` in prototypes.') + print('Stopping training because of `nan` in prototypes.') break # Visualize the data and the prototypes