Update examples/glvq_iris.py to use the recently modified API
This commit is contained in:
parent
1ec7bd261b
commit
6090aad176
@ -1,4 +1,4 @@
|
|||||||
"""ProtoTorch GLVQ example using 2D Iris data"""
|
"""ProtoTorch GLVQ example using 2D Iris data."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
|
|||||||
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.prototypes import AddPrototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
@ -22,10 +22,10 @@ x_train = scaler.transform(x_train)
|
|||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.p1 = AddPrototypes1D(input_dim=2,
|
self.p1 = Prototypes1D(input_dim=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
nclasses=3,
|
nclasses=3,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer='zeros')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.p1.prototypes
|
protos = self.p1.prototypes
|
||||||
|
Loading…
Reference in New Issue
Block a user