Update KNN

This commit is contained in:
Jensun Ravichandran 2021-05-17 16:59:35 +02:00
parent 77b7b59bad
commit 7a87636ad7

View File

@ -5,7 +5,7 @@ import warnings
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_init_arg
from prototorch.components.initializers import parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
@ -24,7 +24,7 @@ class KNN(AbstractPrototypeModel):
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_init_arg(data)
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))