Update iris example
This commit is contained in:
parent
3edb13baf4
commit
429570323e
@ -3,13 +3,13 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from sklearn.datasets import load_iris
|
from prototorch.functions.competitions import wtac
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from torchinfo import summary
|
|
||||||
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from torchinfo import summary
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
@ -58,7 +58,11 @@ for epoch in range(70):
|
|||||||
# Compute loss
|
# Compute loss
|
||||||
dis, plabels = model(x_in)
|
dis, plabels = model(x_in)
|
||||||
loss = criterion([dis, plabels], y_in)
|
loss = criterion([dis, plabels], y_in)
|
||||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")
|
with torch.no_grad():
|
||||||
|
pred = wtac(dis, plabels)
|
||||||
|
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
||||||
|
acc = 100. * correct / len(x_train)
|
||||||
|
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
|
||||||
|
|
||||||
# Take a gradient descent step
|
# Take a gradient descent step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
Loading…
Reference in New Issue
Block a user