Update examples/glvq_iris.py to use the recently modified API
This commit is contained in:
parent
1ec7bd261b
commit
6090aad176
@ -1,4 +1,4 @@
|
||||
"""ProtoTorch GLVQ example using 2D Iris data"""
|
||||
"""ProtoTorch GLVQ example using 2D Iris data."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import AddPrototypes1D
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
# Prepare and preprocess the data
|
||||
scaler = StandardScaler()
|
||||
@ -22,10 +22,10 @@ x_train = scaler.transform(x_train)
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.p1 = AddPrototypes1D(input_dim=2,
|
||||
prototypes_per_class=1,
|
||||
nclasses=3,
|
||||
prototype_initializer='zeros')
|
||||
self.p1 = Prototypes1D(input_dim=2,
|
||||
prototypes_per_class=1,
|
||||
nclasses=3,
|
||||
prototype_initializer='zeros')
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
|
Loading…
Reference in New Issue
Block a user