prototorch/examples/lgmlvq_iris.py

109 lines
3.2 KiB
Python
Raw Normal View History

2020-09-24 14:59:42 +00:00
"""ProtoTorch LGMLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
2020-09-24 14:59:42 +00:00
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
2020-09-24 14:59:42 +00:00
# Prepare training data
# `return_X_y` must be passed by keyword: recent scikit-learn releases made it
# keyword-only, so the positional form `load_iris(True)` raises a TypeError.
x_train, y_train = load_iris(return_X_y=True)
# Keep only feature columns 0 and 2 so the data (and the plots below) are 2D.
x_train = x_train[:, [0, 2]]
# Define the model
class Model(torch.nn.Module):
    """Local-GMLVQ model: labeled prototypes plus one Omega matrix each.

    The prototypes are initialized from the module-level `x_train` and
    `y_train` arrays, so those must exist before the model is instantiated.
    """

    def __init__(self):
        """Local-GMLVQ model."""
        super().__init__()
        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
        # Presumably the number of prototypes per class -- confirm against
        # the `LabeledComponents` documentation.
        prototype_distribution = [1, 2, 2]
        self.proto_layer = LabeledComponents(
            prototype_distribution,
            prototype_initializer,
        )
        # One square matrix per prototype, each starting as the identity.
        # Both sizes are derived instead of hard-coded so they cannot drift
        # out of sync with the prototype distribution or the training data.
        num_prototypes = sum(prototype_distribution)
        input_dim = x_train.shape[1]
        omegas = torch.eye(input_dim, input_dim).repeat(num_prototypes, 1, 1)
        self.omegas = torch.nn.Parameter(omegas)

    def forward(self, x):
        """Return `lomega_distance` of `x` to the prototypes, and their labels.

        Args:
            x: Batch of input samples, passed to `lomega_distance` unchanged.

        Returns:
            A `(dis, plabels)` tuple: the output of `lomega_distance` and the
            prototype labels produced by the prototype layer.
        """
        protos, plabels = self.proto_layer()
        dis = lomega_distance(x, protos, self.omegas)
        return dis, plabels
# Convert the training arrays into tensors for the forward pass and the loss
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Instantiate the network
model = Model()

# GLVQ loss with a "sigmoid_beta" squashing function
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

# Adam from `torch.optim` updates every parameter registered on the model
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# Training loop
# Each epoch performs one full-batch optimization step and then redraws the
# data, the prototypes and the decision regions in the same figure.
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(100):
    # Compute loss
    dis, plabels = model(x_in)
    loss = criterion([dis, plabels], y_in)

    # Predicted class = column index of the smallest per-class distance
    y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
    acc = accuracy_score(y_train, y_pred)
    log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
    log_string += f"Acc: {acc * 100:05.02f}%"
    print(log_string)

    # Take a gradient descent step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    # Detach explicitly so the NumPy conversion never fails on a tensor that
    # still requires grad.
    protos = model.proto_layer.components.detach().numpy()

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(title)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        protos[:, 0],
        protos[:, 1],
        c=plabels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    # The mesh covers data and prototypes with a margin of 1 on every side.
    x = np.vstack((x_train, protos))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]
    # This forward pass is for plotting only, so skip autograd bookkeeping
    # for the (large) mesh.
    with torch.no_grad():
        d, plabels = model(torch.Tensor(mesh_input))
        y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
    y_pred = y_pred.reshape(xx.shape)

    # Plot voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
    ax.set_xlim(left=x_min + 0, right=x_max - 0)
    ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
    # Short pause so the figure refreshes between epochs
    plt.pause(0.1)