Add stratified_min competition function
This commit is contained in:
parent
3cfbc49254
commit
a3548e0ddd
@ -3,6 +3,31 @@
|
||||
import torch
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def stratified_min(distances, labels):
|
||||
clabels = torch.unique(labels, dim=0)
|
||||
nclasses = clabels.size()[0]
|
||||
batch_size = distances.size()[0]
|
||||
winning_distances = torch.zeros(nclasses, batch_size)
|
||||
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
||||
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||
for i, cl in enumerate(clabels):
|
||||
# cdists = distances.T[labels == cl]
|
||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||
if labels.ndim == 2:
|
||||
# if the labels are one-hot vectors
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
cdists = torch.where(matcher, distances.T, inf).T
|
||||
winning_distances[i] = torch.min(cdists, dim=1,
|
||||
keepdim=True).values.squeeze()
|
||||
if labels.ndim == 2:
|
||||
# Transpose to return with `batch_size` first and
|
||||
# reverse the columns to fix the ordering of the classes
|
||||
return torch.flip(winning_distances.T, dims=(1, ))
|
||||
|
||||
return winning_distances.T # return with `batch_size` first
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def wtac(distances, labels):
|
||||
winning_indices = torch.min(distances, dim=1).indices
|
||||
|
Loading…
Reference in New Issue
Block a user