2020-04-11 12:29:06 +00:00
|
|
|
"""ProtoTorch GLVQ example using 2D Iris data."""
|
2020-04-06 14:43:59 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from matplotlib import pyplot as plt
|
2021-05-28 13:57:26 +00:00
|
|
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
2021-03-26 15:06:11 +00:00
|
|
|
from prototorch.functions.competitions import wtac
|
2020-04-06 14:43:59 +00:00
|
|
|
from prototorch.functions.distances import euclidean_distance
|
|
|
|
from prototorch.modules.losses import GLVQLoss
|
2021-05-25 13:57:05 +00:00
|
|
|
from sklearn.datasets import load_iris
|
|
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
from torchinfo import summary
|
2020-04-06 14:43:59 +00:00
|
|
|
|
|
|
|
# Prepare and preprocess the data: load Iris, keep only feature columns 0 and
# 2 so the samples can be drawn in 2D, then standardize those two columns with
# a StandardScaler fitted on exactly the same data.
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = scaler.fit_transform(x_train[:, [0, 2]])
|
|
|
|
|
|
|
|
|
|
|
|
# Define the GLVQ model
class Model(torch.nn.Module):
    """GLVQ model for training on 2D Iris data.

    Wraps a single ``LabeledComponents`` layer holding three prototypes for
    each of three classes. The prototypes are initialized by
    ``StratifiedMeanInitializer`` from the module-level ``x_train`` and
    ``y_train`` arrays, so those globals must exist before instantiation.
    """

    def __init__(self):
        """Build the labeled prototype layer."""
        super().__init__()
        init = StratifiedMeanInitializer([x_train, y_train])
        self.proto_layer = LabeledComponents(
            {"num_classes": 3, "prototypes_per_class": 3},
            init,
        )

    def forward(self, x):
        """Compare ``x`` against all prototypes.

        Args:
            x: Input samples, forwarded as-is to ``euclidean_distance``
                together with the current prototypes.

        Returns:
            Tuple ``(distances, prototype_labels)``. Callers take
            ``argmin`` over ``dim=1`` of ``distances`` to find the closest
            prototype, so presumably the second axis indexes prototypes
            (TODO confirm against ``euclidean_distance``).
        """
        protos, labels = self.proto_layer()
        return euclidean_distance(x, protos), labels
|
2020-04-06 14:43:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Build the GLVQ model
model = Model()

# Print summary using torchinfo (might be buggy/incorrect).
# In a non-interactive run, torchinfo's ``summary`` already prints its report
# when ``verbose`` is left at its default. Wrapping that call in ``print``
# therefore showed the table twice. ``verbose=0`` silences the internal print,
# so the explicit ``print`` emits the report exactly once.
print(summary(model, verbose=0))

# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

# Full-batch training tensors. ``torch.Tensor`` uses the default floating
# dtype, so the class labels in ``y_in`` are floats as well.
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
|
|
|
|
|
2020-04-06 14:43:59 +00:00
|
|
|
# Training loop
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)
for epoch in range(70):
    # Compute loss
    distances, prototype_labels = model(x_in)
    loss = criterion([distances, prototype_labels], y_in)

    # Compute accuracy: ``wtac`` turns distances and prototype labels into
    # per-sample predicted labels, which are compared against ``y_in``.
    with torch.no_grad():
        predictions = wtac(distances, prototype_labels)
        correct = predictions.eq(y_in.view_as(predictions)).sum().item()
    acc = 100.0 * correct / len(x_train)

    print(
        f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
    )

    # Optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model. ``detach().cpu()`` keeps the NumPy
    # conversion safe even if ``components`` still carries autograd history
    # or lives on a non-CPU device (it is a no-op otherwise).
    prototypes = model.proto_layer.components.detach().cpu().numpy()
    if np.isnan(np.sum(prototypes)):
        print("Stopping training because of `nan` in prototypes.")
        break

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(TITLE)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    # One colormap for data, prototypes and decision regions, so all three
    # layers share the same color scheme.
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=cmap, edgecolor="k")
    ax.scatter(
        prototypes[:, 0],
        prototypes[:, 1],
        c=prototype_labels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions on a grid spanning data and prototypes plus a
    # margin of 1 on every side, sampled with a step of 1/50.
    x = np.vstack((x_train, prototypes))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    # Inference only: no_grad avoids building an autograd graph over every
    # grid point on every epoch. Each grid point gets the label of the
    # prototype with the smallest distance.
    with torch.no_grad():
        torch_input = torch.Tensor(mesh_input)
        d = model(torch_input)[0]
        w_indices = torch.argmin(d, dim=1)
        y_pred = torch.index_select(prototype_labels, 0, w_indices)
    y_pred = y_pred.reshape(xx.shape)

    # Plot voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    plt.pause(0.1)
|