Update examples/glvq_iris.py script
This commit is contained in:
parent
42cedbb2b8
commit
3edb13baf4
@ -5,6 +5,7 @@ import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
@ -27,7 +28,7 @@ class Model(torch.nn.Module):
|
||||
input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
prototype_initializer='stratified_random',
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train])
|
||||
|
||||
def forward(self, x):
|
||||
@ -40,21 +41,24 @@ class Model(torch.nn.Module):
|
||||
# Build the GLVQ model
|
||||
model = Model()
|
||||
|
||||
# Print summary using torchinfo (might be buggy/incorrect)
|
||||
print(summary(model))
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||
|
||||
x_in = torch.Tensor(x_train)
|
||||
y_in = torch.Tensor(y_train)
|
||||
|
||||
# Training loop
|
||||
title = 'Prototype Visualization'
|
||||
title = "Prototype Visualization"
|
||||
fig = plt.figure(title)
|
||||
for epoch in range(70):
|
||||
# Compute loss
|
||||
dis, plabels = model(x_in)
|
||||
loss = criterion([dis, plabels], y_in)
|
||||
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
|
||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
|
||||
|
||||
# Take a gradient descent step
|
||||
optimizer.zero_grad()
|
||||
@ -64,23 +68,23 @@ for epoch in range(70):
|
||||
# Get the prototypes form the model
|
||||
protos = model.proto_layer.prototypes.data.numpy()
|
||||
if np.isnan(np.sum(protos)):
|
||||
print('Stopping training because of `nan` in prototypes.')
|
||||
print("Stopping training because of `nan` in prototypes.")
|
||||
break
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel('Data dimension 1')
|
||||
ax.set_ylabel('Data dimension 2')
|
||||
cmap = 'viridis'
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
cmap = "viridis"
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||
ax.scatter(protos[:, 0],
|
||||
protos[:, 1],
|
||||
c=plabels,
|
||||
cmap=cmap,
|
||||
edgecolor='k',
|
||||
marker='D',
|
||||
edgecolor="k",
|
||||
marker="D",
|
||||
s=50)
|
||||
|
||||
# Paint decision regions
|
||||
|
Loading…
Reference in New Issue
Block a user