Update examples/glvq_iris.py to use the recently modified API

This commit is contained in:
blackfly 2020-04-11 14:29:06 +02:00
parent 1ec7bd261b
commit 6090aad176

View File

@ -1,4 +1,4 @@
"""ProtoTorch GLVQ example using 2D Iris data""" """ProtoTorch GLVQ example using 2D Iris data."""
import numpy as np import numpy as np
import torch import torch
@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import AddPrototypes1D from prototorch.modules.prototypes import Prototypes1D
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()
@ -22,10 +22,10 @@ x_train = scaler.transform(x_train)
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__() super().__init__()
self.p1 = AddPrototypes1D(input_dim=2, self.p1 = Prototypes1D(input_dim=2,
prototypes_per_class=1, prototypes_per_class=1,
nclasses=3, nclasses=3,
prototype_initializer='zeros') prototype_initializer='zeros')
def forward(self, x): def forward(self, x):
protos = self.p1.prototypes protos = self.p1.prototypes