prototorch_models/prototorch/models/knn.py

48 lines
1.4 KiB
Python
Raw Normal View History

2021-06-04 20:20:32 +00:00
"""ProtoTorch KNN model."""
import warnings
from ..core.competitions import KNNC
from ..core.components import LabeledComponents
from ..core.initializers import (
LiteralCompInitializer,
LiteralLabelsInitializer,
)
from ..utils.utils import parse_data_arg
2021-06-04 20:20:32 +00:00
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
    """K-Nearest-Neighbors classification algorithm.

    The training data is stored directly as the labeled components of the
    prototype layer, and classification is delegated to the ``KNNC``
    competition layer. There is nothing to fit: training is skipped and no
    optimizer is configured.
    """

    def __init__(self, hparams, **kwargs):
        """Build the model from the training data.

        Args:
            hparams: Hyperparameters forwarded to ``SupervisedPrototypeModel``.
                ``k`` (number of neighbors used by ``KNNC``) defaults to 1
                when not present.
            **kwargs: Forwarded to the base class. Must contain ``data``,
                which is split into samples and targets by
                ``parse_data_arg``.

        Raises:
            ValueError: If no ``data`` keyword argument is given.
        """
        # ``skip_proto_layer=True`` presumably stops the base class from
        # building its own prototype layer, since ``self.proto_layer`` is
        # assigned from the data below — confirm against the base class.
        super().__init__(hparams, skip_proto_layer=True, **kwargs)

        # Default hparams
        self.hparams.setdefault("k", 1)

        data = kwargs.get("data", None)
        if data is None:
            raise ValueError("KNN requires data, but was not provided!")

        data, targets = parse_data_arg(data)

        # Layers
        # The components are the samples themselves and their labels are the
        # targets; ``distribution`` is a list of ``len(data)`` ones
        # (presumably one component per sample).
        self.proto_layer = LabeledComponents(
            distribution=len(data) * [1],
            components_initializer=LiteralCompInitializer(data),
            labels_initializer=LiteralLabelsInitializer(targets),
        )
        self.competition_layer = KNNC(k=self.hparams.k)

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Skip the batch; k-NN has nothing to learn.

        Returns ``None``, which Lightning treats as "skip this batch". A plain
        integer is not an accepted training-step result; Lightning expects a
        loss Tensor, a dict containing ``"loss"``, or ``None``.
        """
        return None  # skip training step

    def on_train_batch_start(self,
                             train_batch,
                             batch_idx,
                             dataloader_idx=None):
        """Warn that k-NN is not trained and abort the training epoch.

        Returning ``-1`` from this Lightning hook skips the rest of the
        current training epoch, so ``training_step`` is normally not reached.
        """
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        """Return no optimizer, since the model has nothing to optimize."""
        return None