Update KNN
This commit is contained in:
parent
77b7b59bad
commit
7a87636ad7
@ -5,7 +5,7 @@ import warnings
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.components.initializers import parse_init_arg
|
||||
from prototorch.components.initializers import parse_data_arg
|
||||
from prototorch.functions.competitions import knnc
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
|
||||
@ -24,7 +24,7 @@ class KNN(AbstractPrototypeModel):
|
||||
self.hparams.setdefault("distance", euclidean_distance)
|
||||
|
||||
data = kwargs.get("data")
|
||||
x_train, y_train = parse_init_arg(data)
|
||||
x_train, y_train = parse_data_arg(data)
|
||||
|
||||
self.proto_layer = LabeledComponents(initialized_components=(x_train,
|
||||
y_train))
|
||||
|
Loading…
Reference in New Issue
Block a user