Update examples/glvq_iris.py script

Jensun Ravichandran 2021-03-01 18:52:54 +01:00
parent 42cedbb2b8
commit 3edb13baf4
2 changed files with 16 additions and 11 deletions

examples/glvq_iris.py

@@ -5,6 +5,7 @@ import torch
 from matplotlib import pyplot as plt
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
+from torchinfo import summary

 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
@@ -27,7 +28,7 @@ class Model(torch.nn.Module):
             input_dim=2,
             prototypes_per_class=3,
             nclasses=3,
-            prototype_initializer='stratified_random',
+            prototype_initializer="stratified_random",
             data=[x_train, y_train])

     def forward(self, x):
@@ -40,21 +41,24 @@ class Model(torch.nn.Module):
 # Build the GLVQ model
 model = Model()

+# Print summary using torchinfo (might be buggy/incorrect)
+print(summary(model))
+
 # Optimize using SGD optimizer from `torch.optim`
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
+criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

 x_in = torch.Tensor(x_train)
 y_in = torch.Tensor(y_train)

 # Training loop
-title = 'Prototype Visualization'
+title = "Prototype Visualization"
 fig = plt.figure(title)
 for epoch in range(70):
     # Compute loss
     dis, plabels = model(x_in)
     loss = criterion([dis, plabels], y_in)
-    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
+    print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")

     # Take a gradient descent step
     optimizer.zero_grad()
@@ -64,23 +68,23 @@ for epoch in range(70):
     # Get the prototypes form the model
     protos = model.proto_layer.prototypes.data.numpy()
     if np.isnan(np.sum(protos)):
-        print('Stopping training because of `nan` in prototypes.')
+        print("Stopping training because of `nan` in prototypes.")
         break

     # Visualize the data and the prototypes
     ax = fig.gca()
     ax.cla()
     ax.set_title(title)
-    ax.set_xlabel('Data dimension 1')
-    ax.set_ylabel('Data dimension 2')
-    cmap = 'viridis'
-    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
+    ax.set_xlabel("Data dimension 1")
+    ax.set_ylabel("Data dimension 2")
+    cmap = "viridis"
+    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
     ax.scatter(protos[:, 0],
                protos[:, 1],
                c=plabels,
                cmap=cmap,
-               edgecolor='k',
-               marker='D',
+               edgecolor="k",
+               marker="D",
                s=50)

     # Paint decision regions
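
The only functional change in this file is the torchinfo summary call added above; the cautious "might be buggy/incorrect" comment likely reflects that the model's forward returns a tuple rather than a plain tensor, which layer-summary tools do not always trace cleanly. As a rough illustration of what print(summary(model)) produces, here is a minimal sketch against a made-up TinyNet module (assuming a recent torchinfo release; this is not the example's GLVQ model):

    import torch
    from torchinfo import summary


    class TinyNet(torch.nn.Module):
        """Stand-in module, only for demonstrating torchinfo output."""
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(2, 3)

        def forward(self, x):
            return self.fc(x)


    model = TinyNet()

    # Without input_size, torchinfo lists the layer hierarchy and parameter counts.
    print(summary(model))

    # With input_size, it additionally traces per-layer output shapes.
    print(summary(model, input_size=(1, 2)))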

setup.py

@@ -27,6 +27,7 @@ DATASETS = [
 EXAMPLES = [
     "sklearn",
     "matplotlib",
+    "torchinfo",
 ]
 TESTS = ["pytest"]
 ALL = DOCS + DATASETS + EXAMPLES + TESTS
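
Adding "torchinfo" to the EXAMPLES list declares the new dependency as an optional extra. The setup() call itself is not part of this diff, so the wiring below is only a sketch of the usual setuptools pattern; the distribution name and extra names are assumptions, not taken from the commit:

    from setuptools import find_packages, setup

    EXAMPLES = [
        "sklearn",
        "matplotlib",
        "torchinfo",
    ]
    TESTS = ["pytest"]

    setup(
        name="prototorch",  # assumed distribution name
        packages=find_packages(),
        extras_require={
            "examples": EXAMPLES,  # pip install "prototorch[examples]"
            "tests": TESTS,
        },
    )

With an extras mapping along these lines, installing the package with the examples extra pulls in torchinfo together with the plotting and sklearn dependencies needed to run the example scripts.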