prototorch/examples/glvq_iris.py

121 lines
3.6 KiB
Python
Raw Normal View History

"""ProtoTorch GLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
2021-03-26 15:06:11 +00:00
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
2021-05-25 13:57:05 +00:00
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data
# Load Iris and keep only feature columns 0 and 2 so the inputs are 2D.
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]

# Standardize the two selected features; `scaler` stays fitted on them.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
# Define the GLVQ model
class Model(torch.nn.Module):
    """GLVQ model for training on 2D Iris data.

    Holds a single ``LabeledComponents`` layer (the prototypes and their
    labels) and exposes the distances between inputs and prototypes.
    """

    def __init__(self):
        """Build the prototype layer: 3 classes with 3 prototypes each.

        The prototypes are initialized by ``StratifiedMeanInitializer`` from
        the module-level ``x_train`` / ``y_train`` arrays.
        """
        super().__init__()
        self.proto_layer = LabeledComponents(
            {"num_classes": 3, "prototypes_per_class": 3},
            StratifiedMeanInitializer([x_train, y_train]),
        )

    def forward(self, x):
        """Return ``(distances, prototype_labels)`` for the inputs ``x``.

        ``distances`` is the result of ``euclidean_distance(x, prototypes)``;
        ``prototype_labels`` is whatever the prototype layer reports as the
        labels of its components.
        """
        protos, proto_labels = self.proto_layer()
        return euclidean_distance(x, protos), proto_labels
# Build the GLVQ model
model = Model()

# Print summary using torchinfo (might be buggy/incorrect)
model_summary = summary(model)
print(model_summary)

# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

# GLVQ loss with the "sigmoid_beta" squashing function
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

# Wrap the numpy training data as torch tensors for the training loop.
# `torch.Tensor(...)` is used on purpose: it yields the default float dtype.
x_in, y_in = torch.Tensor(x_train), torch.Tensor(y_train)
# Training loop
#
# Each epoch: forward pass on the full training set, report loss/accuracy,
# take one SGD step, then redraw the data, the prototypes and the decision
# regions in a single matplotlib figure.
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)

# Loop-invariant plotting settings, hoisted out of the epoch loop.
cmap = "viridis"
mesh_step = 1 / 50  # grid resolution used to paint the decision regions

for epoch in range(70):
    # Compute loss
    distances, prototype_labels = model(x_in)
    loss = criterion([distances, prototype_labels], y_in)

    # Compute Accuracy (no gradients needed for the bookkeeping)
    with torch.no_grad():
        predictions = wtac(distances, prototype_labels)
        correct = predictions.eq(y_in.view_as(predictions)).sum().item()
    acc = 100.0 * correct / len(x_train)

    print(
        f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
    )

    # Optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model.
    # `.detach()` makes the numpy conversion safe even if `components` hands
    # back a tensor that still requires grad; it is a no-op otherwise.
    prototypes = model.proto_layer.components.detach().numpy()
    if np.isnan(np.sum(prototypes)):
        print("Stopping training because of `nan` in prototypes.")
        break

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()  # redraw from scratch every epoch
    ax.set_title(TITLE)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    ax.scatter(
        x_train[:, 0],
        x_train[:, 1],
        c=y_train,
        cmap=cmap,
        edgecolor="k",
    )
    ax.scatter(
        prototypes[:, 0],
        prototypes[:, 1],
        c=prototype_labels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions: build a dense grid that covers both the data
    # and the current prototypes, with a margin of 1 on every side.
    x = np.vstack((x_train, prototypes))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, mesh_step),
        np.arange(y_min, y_max, mesh_step),
    )
    mesh_input = np.c_[xx.ravel(), yy.ravel()]
    torch_input = torch.Tensor(mesh_input)

    # Inference only: skip autograd so no graph is built for the large mesh.
    with torch.no_grad():
        d = model(torch_input)[0]
        # Label every grid point with the label of its closest prototype.
        w_indices = torch.argmin(d, dim=1)
        y_pred = torch.index_select(prototype_labels, 0, w_indices)
    y_pred = y_pred.reshape(xx.shape)

    # Plot voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    # Give the GUI event loop time to render the updated figure.
    plt.pause(0.1)