[BUGFIX] Fix knnc

This commit is contained in:
Jensun Ravichandran 2021-05-11 17:06:27 +02:00
parent 7bb93f027a
commit ae6bc47f87

View File

@ -3,7 +3,6 @@
import torch import torch
# @torch.jit.script
def stratified_min(distances, labels): def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0) clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0] nclasses = clabels.size()[0]
@ -31,15 +30,14 @@ def stratified_min(distances, labels):
return winning_distances.T # return with `batch_size` first return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels
# @torch.jit.script def knnc(distances, labels, k=1):
def knnc(distances, labels, k): winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices winning_labels = torch.mode(labels[winning_indices].squeeze(),
winning_labels = labels[winning_indices].squeeze() dim=1).values
return winning_labels return winning_labels