Update examples/glvq_iris.py script

This commit is contained in:
Jensun Ravichandran 2021-03-01 18:52:54 +01:00
parent 42cedbb2b8
commit 3edb13baf4
2 changed files with 16 additions and 11 deletions

View File

@ -5,6 +5,7 @@ import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
@ -27,7 +28,7 @@ class Model(torch.nn.Module):
input_dim=2,
prototypes_per_class=3,
nclasses=3,
prototype_initializer='stratified_random',
prototype_initializer="stratified_random",
data=[x_train, y_train])
def forward(self, x):
@ -40,21 +41,24 @@ class Model(torch.nn.Module):
# Build the GLVQ model
model = Model()
# Print summary using torchinfo (might be buggy/incorrect)
print(summary(model))
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
title = 'Prototype Visualization'
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(70):
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
# Take a gradient descent step
optimizer.zero_grad()
@ -64,23 +68,23 @@ for epoch in range(70):
# Get the prototypes form the model
protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)):
print('Stopping training because of `nan` in prototypes.')
print("Stopping training because of `nan` in prototypes.")
break
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
ax.set_title(title)
ax.set_xlabel('Data dimension 1')
ax.set_ylabel('Data dimension 2')
cmap = 'viridis'
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0],
protos[:, 1],
c=plabels,
cmap=cmap,
edgecolor='k',
marker='D',
edgecolor="k",
marker="D",
s=50)
# Paint decision regions

View File

@ -27,6 +27,7 @@ DATASETS = [
EXAMPLES = [
"sklearn",
"matplotlib",
"torchinfo",
]
TESTS = ["pytest"]
ALL = DOCS + DATASETS + EXAMPLES + TESTS