diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index fd23fee..9b14cbc 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -3,13 +3,13 @@ import numpy as np import torch from matplotlib import pyplot as plt -from sklearn.datasets import load_iris -from sklearn.preprocessing import StandardScaler -from torchinfo import summary - +from prototorch.functions.competitions import wtac from prototorch.functions.distances import euclidean_distance from prototorch.modules.losses import GLVQLoss from prototorch.modules.prototypes import Prototypes1D +from sklearn.datasets import load_iris +from sklearn.preprocessing import StandardScaler +from torchinfo import summary # Prepare and preprocess the data scaler = StandardScaler() @@ -58,7 +58,11 @@ for epoch in range(70): # Compute loss dis, plabels = model(x_in) loss = criterion([dis, plabels], y_in) - print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}") + with torch.no_grad(): + pred = wtac(dis, plabels) + correct = pred.eq(y_in.view_as(pred)).sum().item() + acc = 100. * correct / len(x_train) + print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%") # Take a gradient descent step optimizer.zero_grad()