Update examples/glvq_iris.py script
This commit is contained in:
		@@ -5,6 +5,7 @@ import torch
 | 
				
			|||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from sklearn.preprocessing import StandardScaler
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
 | 
					from torchinfo import summary
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					from prototorch.functions.distances import euclidean_distance
 | 
				
			||||||
from prototorch.modules.losses import GLVQLoss
 | 
					from prototorch.modules.losses import GLVQLoss
 | 
				
			||||||
@@ -27,7 +28,7 @@ class Model(torch.nn.Module):
 | 
				
			|||||||
            input_dim=2,
 | 
					            input_dim=2,
 | 
				
			||||||
            prototypes_per_class=3,
 | 
					            prototypes_per_class=3,
 | 
				
			||||||
            nclasses=3,
 | 
					            nclasses=3,
 | 
				
			||||||
            prototype_initializer='stratified_random',
 | 
					            prototype_initializer="stratified_random",
 | 
				
			||||||
            data=[x_train, y_train])
 | 
					            data=[x_train, y_train])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
@@ -40,21 +41,24 @@ class Model(torch.nn.Module):
 | 
				
			|||||||
# Build the GLVQ model
 | 
					# Build the GLVQ model
 | 
				
			||||||
model = Model()
 | 
					model = Model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Print summary using torchinfo (might be buggy/incorrect)
 | 
				
			||||||
 | 
					print(summary(model))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Optimize using SGD optimizer from `torch.optim`
 | 
					# Optimize using SGD optimizer from `torch.optim`
 | 
				
			||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 | 
					optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 | 
				
			||||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
 | 
					criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
x_in = torch.Tensor(x_train)
 | 
					x_in = torch.Tensor(x_train)
 | 
				
			||||||
y_in = torch.Tensor(y_train)
 | 
					y_in = torch.Tensor(y_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Training loop
 | 
					# Training loop
 | 
				
			||||||
title = 'Prototype Visualization'
 | 
					title = "Prototype Visualization"
 | 
				
			||||||
fig = plt.figure(title)
 | 
					fig = plt.figure(title)
 | 
				
			||||||
for epoch in range(70):
 | 
					for epoch in range(70):
 | 
				
			||||||
    # Compute loss
 | 
					    # Compute loss
 | 
				
			||||||
    dis, plabels = model(x_in)
 | 
					    dis, plabels = model(x_in)
 | 
				
			||||||
    loss = criterion([dis, plabels], y_in)
 | 
					    loss = criterion([dis, plabels], y_in)
 | 
				
			||||||
    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
 | 
					    print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Take a gradient descent step
 | 
					    # Take a gradient descent step
 | 
				
			||||||
    optimizer.zero_grad()
 | 
					    optimizer.zero_grad()
 | 
				
			||||||
@@ -64,23 +68,23 @@ for epoch in range(70):
 | 
				
			|||||||
    # Get the prototypes form the model
 | 
					    # Get the prototypes form the model
 | 
				
			||||||
    protos = model.proto_layer.prototypes.data.numpy()
 | 
					    protos = model.proto_layer.prototypes.data.numpy()
 | 
				
			||||||
    if np.isnan(np.sum(protos)):
 | 
					    if np.isnan(np.sum(protos)):
 | 
				
			||||||
        print('Stopping training because of `nan` in prototypes.')
 | 
					        print("Stopping training because of `nan` in prototypes.")
 | 
				
			||||||
        break
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Visualize the data and the prototypes
 | 
					    # Visualize the data and the prototypes
 | 
				
			||||||
    ax = fig.gca()
 | 
					    ax = fig.gca()
 | 
				
			||||||
    ax.cla()
 | 
					    ax.cla()
 | 
				
			||||||
    ax.set_title(title)
 | 
					    ax.set_title(title)
 | 
				
			||||||
    ax.set_xlabel('Data dimension 1')
 | 
					    ax.set_xlabel("Data dimension 1")
 | 
				
			||||||
    ax.set_ylabel('Data dimension 2')
 | 
					    ax.set_ylabel("Data dimension 2")
 | 
				
			||||||
    cmap = 'viridis'
 | 
					    cmap = "viridis"
 | 
				
			||||||
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
 | 
					    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
    ax.scatter(protos[:, 0],
 | 
					    ax.scatter(protos[:, 0],
 | 
				
			||||||
               protos[:, 1],
 | 
					               protos[:, 1],
 | 
				
			||||||
               c=plabels,
 | 
					               c=plabels,
 | 
				
			||||||
               cmap=cmap,
 | 
					               cmap=cmap,
 | 
				
			||||||
               edgecolor='k',
 | 
					               edgecolor="k",
 | 
				
			||||||
               marker='D',
 | 
					               marker="D",
 | 
				
			||||||
               s=50)
 | 
					               s=50)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Paint decision regions
 | 
					    # Paint decision regions
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user