2021-06-04 20:20:32 +00:00
|
|
|
"""ProtoTorch KNN model."""
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
|
2022-05-16 09:12:53 +00:00
|
|
|
from prototorch.core.competitions import KNNC
|
|
|
|
from prototorch.core.components import LabeledComponents
|
|
|
|
from prototorch.core.initializers import (
|
2022-01-11 17:28:50 +00:00
|
|
|
LiteralCompInitializer,
|
|
|
|
LiteralLabelsInitializer,
|
|
|
|
)
|
2022-05-16 09:12:53 +00:00
|
|
|
from prototorch.utils.utils import parse_data_arg
|
|
|
|
|
2021-06-04 20:20:32 +00:00
|
|
|
from .abstract import SupervisedPrototypeModel
|
|
|
|
|
|
|
|
|
|
|
|
class KNN(SupervisedPrototypeModel):
    """K-Nearest-Neighbors classification algorithm.

    The model has no trainable step: the training data passed through the
    ``data`` keyword argument is stored verbatim as labeled components, and
    classification is delegated to a ``KNNC`` competition layer that uses
    ``hparams.k`` neighbors (``1`` when not given).

    Args:
        hparams: Hyperparameters forwarded to ``SupervisedPrototypeModel``.
            ``k`` is read from here and defaults to ``1`` if absent.
        **kwargs: Forwarded to ``SupervisedPrototypeModel``. Must contain
            ``data``, which ``parse_data_arg`` unpacks into
            ``(data, targets)``.

    Raises:
        ValueError: If ``data`` is missing from ``kwargs`` or is ``None``.
    """

    def __init__(self, hparams, **kwargs):
        # Validate the mandatory ``data`` kwarg before running any parent
        # initialization, so a misconfigured call fails fast and cheaply.
        data = kwargs.get("data", None)
        if data is None:
            raise ValueError("KNN requires data, but was not provided!")

        # ``skip_proto_layer=True`` presumably stops the parent from building
        # its own prototype layer (verify in SupervisedPrototypeModel); a
        # literal one is created below from the training data instead.
        super().__init__(hparams, skip_proto_layer=True, **kwargs)

        # Default hparams
        self.hparams.setdefault("k", 1)

        data, targets = parse_data_arg(data)

        # Layers
        # One component per training sample (``len(data) * [1]``); the
        # components are the samples themselves and their labels are the
        # corresponding targets, both supplied literally.
        self.proto_layer = LabeledComponents(
            distribution=len(data) * [1],
            components_initializer=LiteralCompInitializer(data),
            labels_initializer=LiteralLabelsInitializer(targets),
        )
        self.competition_layer = KNNC(k=self.hparams.k)

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Do nothing: k-NN has nothing to optimize.

        Returns ``None``, the value Lightning documents as "skip to the next
        batch" under automatic optimization. A bare integer is not a valid
        ``training_step`` return value.
        """
        return None  # skip training step

    def on_train_batch_start(self, train_batch, batch_idx):
        """Warn that there is no training and skip the rest of the epoch.

        Returning ``-1`` from this hook tells Lightning to skip the remainder
        of the current epoch, so ``training_step`` is normally never reached.
        """
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        """Return ``None``: fitting runs without any optimizer."""
        return None
|