Update examples/glvq_iris.py script
This commit is contained in:
parent
42cedbb2b8
commit
3edb13baf4
@ -5,6 +5,7 @@ import torch
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from torchinfo import summary
|
||||||
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
@ -27,7 +28,7 @@ class Model(torch.nn.Module):
|
|||||||
input_dim=2,
|
input_dim=2,
|
||||||
prototypes_per_class=3,
|
prototypes_per_class=3,
|
||||||
nclasses=3,
|
nclasses=3,
|
||||||
prototype_initializer='stratified_random',
|
prototype_initializer="stratified_random",
|
||||||
data=[x_train, y_train])
|
data=[x_train, y_train])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -40,21 +41,24 @@ class Model(torch.nn.Module):
|
|||||||
# Build the GLVQ model
|
# Build the GLVQ model
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
||||||
|
# Print summary using torchinfo (might be buggy/incorrect)
|
||||||
|
print(summary(model))
|
||||||
|
|
||||||
# Optimize using SGD optimizer from `torch.optim`
|
# Optimize using SGD optimizer from `torch.optim`
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||||
|
|
||||||
x_in = torch.Tensor(x_train)
|
x_in = torch.Tensor(x_train)
|
||||||
y_in = torch.Tensor(y_train)
|
y_in = torch.Tensor(y_train)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
title = 'Prototype Visualization'
|
title = "Prototype Visualization"
|
||||||
fig = plt.figure(title)
|
fig = plt.figure(title)
|
||||||
for epoch in range(70):
|
for epoch in range(70):
|
||||||
# Compute loss
|
# Compute loss
|
||||||
dis, plabels = model(x_in)
|
dis, plabels = model(x_in)
|
||||||
loss = criterion([dis, plabels], y_in)
|
loss = criterion([dis, plabels], y_in)
|
||||||
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
|
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
|
||||||
|
|
||||||
# Take a gradient descent step
|
# Take a gradient descent step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@ -64,23 +68,23 @@ for epoch in range(70):
|
|||||||
# Get the prototypes form the model
|
# Get the prototypes form the model
|
||||||
protos = model.proto_layer.prototypes.data.numpy()
|
protos = model.proto_layer.prototypes.data.numpy()
|
||||||
if np.isnan(np.sum(protos)):
|
if np.isnan(np.sum(protos)):
|
||||||
print('Stopping training because of `nan` in prototypes.')
|
print("Stopping training because of `nan` in prototypes.")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
# Visualize the data and the prototypes
|
||||||
ax = fig.gca()
|
ax = fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(title)
|
ax.set_title(title)
|
||||||
ax.set_xlabel('Data dimension 1')
|
ax.set_xlabel("Data dimension 1")
|
||||||
ax.set_ylabel('Data dimension 2')
|
ax.set_ylabel("Data dimension 2")
|
||||||
cmap = 'viridis'
|
cmap = "viridis"
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(protos[:, 0],
|
||||||
protos[:, 1],
|
protos[:, 1],
|
||||||
c=plabels,
|
c=plabels,
|
||||||
cmap=cmap,
|
cmap=cmap,
|
||||||
edgecolor='k',
|
edgecolor="k",
|
||||||
marker='D',
|
marker="D",
|
||||||
s=50)
|
s=50)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
|
Loading…
Reference in New Issue
Block a user