prototorch/prototorch/functions/competitions.py
2021-06-04 22:18:46 +02:00

29 lines
767 B
Python

"""ProtoTorch competition functions."""
import torch
def wtac(distances: torch.Tensor,
labels: torch.LongTensor) -> (torch.LongTensor):
"""Winner-Takes-All-Competition.
Returns the labels corresponding to the winners.
"""
winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels
def knnc(distances: torch.Tensor,
labels: torch.LongTensor,
k: int = 1) -> (torch.LongTensor):
"""K-Nearest-Neighbors-Competition.
Returns the labels corresponding to the winners.
"""
winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_labels = torch.mode(labels[winning_indices], dim=1).values
return winning_labels