Add stratified_min competition function
This commit is contained in:
parent
3cfbc49254
commit
a3548e0ddd
@ -3,6 +3,31 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
|
def stratified_min(distances, labels):
|
||||||
|
clabels = torch.unique(labels, dim=0)
|
||||||
|
nclasses = clabels.size()[0]
|
||||||
|
batch_size = distances.size()[0]
|
||||||
|
winning_distances = torch.zeros(nclasses, batch_size)
|
||||||
|
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
||||||
|
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||||
|
for i, cl in enumerate(clabels):
|
||||||
|
# cdists = distances.T[labels == cl]
|
||||||
|
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# if the labels are one-hot vectors
|
||||||
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
|
cdists = torch.where(matcher, distances.T, inf).T
|
||||||
|
winning_distances[i] = torch.min(cdists, dim=1,
|
||||||
|
keepdim=True).values.squeeze()
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# Transpose to return with `batch_size` first and
|
||||||
|
# reverse the columns to fix the ordering of the classes
|
||||||
|
return torch.flip(winning_distances.T, dims=(1, ))
|
||||||
|
|
||||||
|
return winning_distances.T # return with `batch_size` first
|
||||||
|
|
||||||
|
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def wtac(distances, labels):
|
def wtac(distances, labels):
|
||||||
winning_indices = torch.min(distances, dim=1).indices
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
|
Loading…
Reference in New Issue
Block a user