Add stratified_min competition function

This commit is contained in:
blackfly 2020-04-14 19:52:59 +02:00
parent 3cfbc49254
commit a3548e0ddd

View File

@ -3,6 +3,31 @@
import torch import torch
# @torch.jit.script
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float('inf'))
# distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels):
# cdists = distances.T[labels == cl]
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2:
# if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze()
if labels.ndim == 2:
# Transpose to return with `batch_size` first and
# reverse the columns to fix the ordering of the classes
return torch.flip(winning_distances.T, dims=(1, ))
return winning_distances.T # return with `batch_size` first
# @torch.jit.script # @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices