"""ProtoTorch GLVQ example using 2D Iris data.""" import numpy as np import torch from matplotlib import pyplot as plt from prototorch.components import LabeledComponents, StratifiedMeanInitializer from prototorch.functions.competitions import wtac from prototorch.functions.distances import euclidean_distance from prototorch.modules.losses import GLVQLoss from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler from torchinfo import summary # Prepare and preprocess the data scaler = StandardScaler() x_train, y_train = load_iris(return_X_y=True) x_train = x_train[:, [0, 2]] scaler.fit(x_train) x_train = scaler.transform(x_train) # Define the GLVQ model class Model(torch.nn.Module): def __init__(self): """GLVQ model for training on 2D Iris data.""" super().__init__() prototype_initializer = StratifiedMeanInitializer([x_train, y_train]) prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3} self.proto_layer = LabeledComponents( prototype_distribution, prototype_initializer, ) def forward(self, x): prototypes, prototype_labels = self.proto_layer() distances = euclidean_distance(x, prototypes) return distances, prototype_labels # Build the GLVQ model model = Model() # Print summary using torchinfo (might be buggy/incorrect) print(summary(model)) # Optimize using SGD optimizer from `torch.optim` optimizer = torch.optim.SGD(model.parameters(), lr=0.01) criterion = GLVQLoss(squashing="sigmoid_beta", beta=10) x_in = torch.Tensor(x_train) y_in = torch.Tensor(y_train) # Training loop TITLE = "Prototype Visualization" fig = plt.figure(TITLE) for epoch in range(70): # Compute loss distances, prototype_labels = model(x_in) loss = criterion([distances, prototype_labels], y_in) # Compute Accuracy with torch.no_grad(): predictions = wtac(distances, prototype_labels) correct = predictions.eq(y_in.view_as(predictions)).sum().item() acc = 100.0 * correct / len(x_train) print( f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%" ) # Optimizer step optimizer.zero_grad() loss.backward() optimizer.step() # Get the prototypes form the model prototypes = model.proto_layer.components.numpy() if np.isnan(np.sum(prototypes)): print("Stopping training because of `nan` in prototypes.") break # Visualize the data and the prototypes ax = fig.gca() ax.cla() ax.set_title(TITLE) ax.set_xlabel("Data dimension 1") ax.set_ylabel("Data dimension 2") cmap = "viridis" ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter( prototypes[:, 0], prototypes[:, 1], c=prototype_labels, cmap=cmap, edgecolor="k", marker="D", s=50, ) # Paint decision regions x = np.vstack((x_train, prototypes)) x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), np.arange(y_min, y_max, 1 / 50)) mesh_input = np.c_[xx.ravel(), yy.ravel()] torch_input = torch.Tensor(mesh_input) d = model(torch_input)[0] w_indices = torch.argmin(d, dim=1) y_pred = torch.index_select(prototype_labels, 0, w_indices) y_pred = y_pred.reshape(xx.shape) # Plot voronoi regions ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35) ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0) plt.pause(0.1)