Update KNN

This commit is contained in:
Jensun Ravichandran 2021-05-17 16:59:35 +02:00
parent 77b7b59bad
commit 7a87636ad7

View File

@ -5,7 +5,7 @@ import warnings
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_init_arg from prototorch.components.initializers import parse_data_arg
from prototorch.functions.competitions import knnc from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
@ -24,7 +24,7 @@ class KNN(AbstractPrototypeModel):
self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data") data = kwargs.get("data")
x_train, y_train = parse_init_arg(data) x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train, self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train)) y_train))