Update iris example
This commit is contained in:
parent
3edb13baf4
commit
429570323e
@ -3,13 +3,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
# Prepare and preprocess the data
|
||||
scaler = StandardScaler()
|
||||
@ -58,7 +58,11 @@ for epoch in range(70):
|
||||
# Compute loss
|
||||
dis, plabels = model(x_in)
|
||||
loss = criterion([dis, plabels], y_in)
|
||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
|
||||
with torch.no_grad():
|
||||
pred = wtac(dis, plabels)
|
||||
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
||||
acc = 100. * correct / len(x_train)
|
||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
|
||||
|
||||
# Take a gradient descent step
|
||||
optimizer.zero_grad()
|
||||
|
Loading…
Reference in New Issue
Block a user