Update KNN
This commit is contained in:
parent
77b7b59bad
commit
7a87636ad7
@ -5,7 +5,7 @@ import warnings
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
from prototorch.components.initializers import parse_init_arg
|
from prototorch.components.initializers import parse_data_arg
|
||||||
from prototorch.functions.competitions import knnc
|
from prototorch.functions.competitions import knnc
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ class KNN(AbstractPrototypeModel):
|
|||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
|
|
||||||
data = kwargs.get("data")
|
data = kwargs.get("data")
|
||||||
x_train, y_train = parse_init_arg(data)
|
x_train, y_train = parse_data_arg(data)
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(initialized_components=(x_train,
|
self.proto_layer = LabeledComponents(initialized_components=(x_train,
|
||||||
y_train))
|
y_train))
|
||||||
|
Loading…
Reference in New Issue
Block a user