Update iris example
This commit is contained in:
		@@ -20,8 +20,8 @@ x_train = scaler.transform(x_train)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Define the GLVQ model
 | 
					# Define the GLVQ model
 | 
				
			||||||
class Model(torch.nn.Module):
 | 
					class Model(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, **kwargs):
 | 
					    def __init__(self):
 | 
				
			||||||
        """GLVQ model."""
 | 
					        """GLVQ model for training on 2D Iris data."""
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.proto_layer = Prototypes1D(
 | 
					        self.proto_layer = Prototypes1D(
 | 
				
			||||||
            input_dim=2,
 | 
					            input_dim=2,
 | 
				
			||||||
@@ -64,7 +64,7 @@ for epoch in range(70):
 | 
				
			|||||||
    # Get the prototypes form the model
 | 
					    # Get the prototypes form the model
 | 
				
			||||||
    protos = model.proto_layer.prototypes.data.numpy()
 | 
					    protos = model.proto_layer.prototypes.data.numpy()
 | 
				
			||||||
    if np.isnan(np.sum(protos)):
 | 
					    if np.isnan(np.sum(protos)):
 | 
				
			||||||
        print(f'Stopping because of `nan` in prototypes.')
 | 
					        print('Stopping training because of `nan` in prototypes.')
 | 
				
			||||||
        break
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Visualize the data and the prototypes
 | 
					    # Visualize the data and the prototypes
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user