"""ProtoTorch GLVQ example using 2D Iris data."""
|
|
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
|
from prototorch.functions.competitions import wtac
|
|
from prototorch.functions.distances import euclidean_distance
|
|
from prototorch.modules.losses import GLVQLoss
|
|
from sklearn.datasets import load_iris
|
|
from sklearn.preprocessing import StandardScaler
|
|
from torchinfo import summary
|
|
|
|
# Prepare and preprocess the data
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]  # keep only sepal length and petal length
scaler.fit(x_train)
x_train = scaler.transform(x_train)


# Define the GLVQ model
class Model(torch.nn.Module):
    def __init__(self):
        """GLVQ model for training on 2D Iris data."""
        super().__init__()
        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
        prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
        self.proto_layer = LabeledComponents(
            prototype_distribution,
            prototype_initializer,
        )

    def forward(self, x):
        prototypes, prototype_labels = self.proto_layer()
        distances = euclidean_distance(x, prototypes)
        return distances, prototype_labels

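
# Note on the new component API: calling the `LabeledComponents` layer returns
# the prototype tensor together with its class labels (see `forward` above).
# `StratifiedMeanInitializer` is assumed here to place each initial prototype
# at the mean of its class in the supplied training data.
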
# Build the GLVQ model
model = Model()

# Print summary using torchinfo (might be buggy/incorrect)
print(summary(model))

# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
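# The GLVQ loss compares the distance to the closest prototype of the correct
# class with the distance to the closest prototype of any other class and
# squashes their relative difference, here through a scaled sigmoid (beta=10).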
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Training loop
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)
for epoch in range(70):
    # Compute loss
    distances, prototype_labels = model(x_in)
    loss = criterion([distances, prototype_labels], y_in)

    # Compute Accuracy
    with torch.no_grad():
        # Winner-takes-all competition: predict the label of the closest prototype
        predictions = wtac(distances, prototype_labels)
        correct = predictions.eq(y_in.view_as(predictions)).sum().item()
        acc = 100.0 * correct / len(x_train)

    print(
        f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
    )

    # Optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    prototypes = model.proto_layer.components.numpy()
    if np.isnan(np.sum(prototypes)):
        print("Stopping training because of `nan` in prototypes.")
        break

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(TITLE)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        prototypes[:, 0],
        prototypes[:, 1],
        c=prototype_labels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    x = np.vstack((x_train, prototypes))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    # Classify every mesh point with the label of its closest prototype
    torch_input = torch.Tensor(mesh_input)
    d = model(torch_input)[0]
    w_indices = torch.argmin(d, dim=1)
    y_pred = torch.index_select(prototype_labels, 0, w_indices)
    y_pred = y_pred.reshape(xx.shape)

    # Plot Voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    plt.pause(0.1)
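
# After training, the model could also classify new measurements. A minimal
# sketch, not part of the original example: `new_samples` is a hypothetical
# array of raw sepal-length/petal-length pairs and must be scaled with the
# same `scaler` as the training data.
#
#     new_samples = scaler.transform(np.array([[5.1, 1.4], [6.7, 5.2]]))
#     with torch.no_grad():
#         distances, prototype_labels = model(torch.Tensor(new_samples))
#         predicted_classes = wtac(distances, prototype_labels)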