prototorch_models/prototorch/models/knn.py
2021-06-04 22:20:32 +02:00

39 lines
1.1 KiB
Python

"""ProtoTorch KNN model."""
import warnings
from prototorch.components import LabeledComponents
from prototorch.modules import KNNC
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
    """k-nearest-neighbours classifier.

    The object supplied through the ``data`` keyword argument is forwarded to
    a ``LabeledComponents`` layer as its ``initialized_components``, and a
    ``KNNC`` competition layer is built with the ``k`` hyperparameter
    (defaulting to 1 when the hyperparameters do not define it).

    The training-related hooks are stubs: this model does not fit anything.
    """

    def __init__(self, hparams, **kwargs):
        """Build the component and competition layers.

        Args:
            hparams: Hyperparameters, passed on to the parent constructor.
                ``k`` is filled in with 1 if absent.
            **kwargs: Passed on to the parent constructor unchanged. Must
                contain a non-``None`` ``data`` entry.

        Raises:
            ValueError: If ``data`` is missing from ``kwargs`` or is ``None``.
        """
        super().__init__(hparams, **kwargs)

        # Fall back to a single neighbour when no ``k`` was configured.
        self.hparams.setdefault("k", 1)

        stored_components = kwargs.get("data")
        if stored_components is None:
            raise ValueError("KNN requires data, but was not provided!")

        # The supplied data becomes the components; KNNC is configured with k.
        self.proto_layer = LabeledComponents(
            initialized_components=stored_components)
        num_neighbours = self.hparams.k
        self.competition_layer = KNNC(k=num_neighbours)

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Return the constant 1; no computation is done on the batch."""
        # Placeholder value: there is nothing to optimise for k-NN.
        return 1

    def on_train_batch_start(
        self,
        train_batch,
        batch_idx,
        dataloader_idx=None,
    ):
        """Warn that k-NN has no training and return -1.

        NOTE(review): -1 is presumably the framework's "skip" signal for this
        hook -- confirm against the base class's trainer documentation.
        """
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        """Return ``None``: this model sets up no optimizer."""
        return None