Update iris example

This commit is contained in:
Jensun Ravichandran 2021-03-26 16:06:11 +01:00
parent 3edb13baf4
commit 429570323e

View File

@ -3,13 +3,13 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data
scaler = StandardScaler()
@ -58,7 +58,11 @@ for epoch in range(70):
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
with torch.no_grad():
pred = wtac(dis, plabels)
correct = pred.eq(y_in.view_as(pred)).sum().item()
acc = 100. * correct / len(x_train)
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
# Take a gradient descent step
optimizer.zero_grad()