Refactor examples/glvq_iris.py
This commit is contained in:
parent
a0f20a40f6
commit
63a25e7a38
@ -42,13 +42,17 @@ model = Model()
|
|||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||||
|
|
||||||
|
x_in = torch.Tensor(x_train)
|
||||||
|
y_in = torch.Tensor(y_train)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
fig = plt.figure('Prototype Visualization')
|
title = 'Prototype Visualization'
|
||||||
|
fig = plt.figure(title)
|
||||||
for epoch in range(70):
|
for epoch in range(70):
|
||||||
# Compute loss.
|
# Compute loss
|
||||||
distances, plabels = model(torch.tensor(x_train))
|
dis, plabels = model(x_in)
|
||||||
loss = criterion([distances, plabels], torch.tensor(y_train))
|
loss = criterion([dis, plabels], y_in)
|
||||||
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}')
|
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
|
||||||
|
|
||||||
# Take a gradient descent step
|
# Take a gradient descent step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@ -61,6 +65,9 @@ for epoch in range(70):
|
|||||||
# Visualize the data and the prototypes
|
# Visualize the data and the prototypes
|
||||||
ax = fig.gca()
|
ax = fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set_xlabel('Data dimension 1')
|
||||||
|
ax.set_ylabel('Data dimension 2')
|
||||||
cmap = 'viridis'
|
cmap = 'viridis'
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(protos[:, 0],
|
||||||
@ -72,28 +79,17 @@ for epoch in range(70):
|
|||||||
s=50)
|
s=50)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
border = 1
|
|
||||||
resolution = 50
|
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
x_min, x_max = x[:, 0].min(), x[:, 0].max()
|
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
||||||
y_min, y_max = x[:, 1].min(), x[:, 1].max()
|
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
||||||
x_min, x_max = x_min - border, x_max + border
|
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
||||||
y_min, y_max = y_min - border, y_max + border
|
np.arange(y_min, y_max, 1 / 50))
|
||||||
try:
|
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
|
|
||||||
np.arange(y_min, y_max, 1.0 / resolution))
|
|
||||||
except ValueError as ve:
|
|
||||||
print(ve)
|
|
||||||
raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
|
|
||||||
f'x_min - x_max is {x_max - x_min}.')
|
|
||||||
except MemoryError as me:
|
|
||||||
print(me)
|
|
||||||
raise ValueError('Too many points. ' 'Try reducing the resolution.')
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||||
|
|
||||||
torch_input = torch.from_numpy(mesh_input)
|
torch_input = torch.Tensor(mesh_input)
|
||||||
d = model(torch_input)[0]
|
d = model(torch_input)[0]
|
||||||
y_pred = np.argmin(d.detach().numpy(), axis=1)
|
y_pred = np.argmin(d.detach().numpy(),
|
||||||
|
axis=1) # assume one prototype per class
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = y_pred.reshape(xx.shape)
|
||||||
|
|
||||||
# Plot voronoi regions
|
# Plot voronoi regions
|
||||||
@ -101,4 +97,5 @@ for epoch in range(70):
|
|||||||
|
|
||||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||||
|
|
||||||
plt.pause(0.1)
|
plt.pause(0.1)
|
||||||
|
Loading…
Reference in New Issue
Block a user