diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py index 48cf78c..44c8248 100644 --- a/prototorch/functions/competitions.py +++ b/prototorch/functions/competitions.py @@ -3,6 +3,31 @@ import torch +# @torch.jit.script +def stratified_min(distances, labels): + clabels = torch.unique(labels, dim=0) + nclasses = clabels.size()[0] + batch_size = distances.size()[0] + winning_distances = torch.zeros(nclasses, batch_size) + inf = torch.full_like(distances.T, fill_value=float('inf')) + # distances_to_wpluses = torch.where(matcher, distances, inf) + for i, cl in enumerate(clabels): + # cdists = distances.T[labels == cl] + matcher = torch.eq(labels.unsqueeze(dim=1), cl) + if labels.ndim == 2: + # if the labels are one-hot vectors + matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) + cdists = torch.where(matcher, distances.T, inf).T + winning_distances[i] = torch.min(cdists, dim=1, + keepdim=True).values.squeeze() + if labels.ndim == 2: + # Transpose to return with `batch_size` first and + # reverse the columns to fix the ordering of the classes + return torch.flip(winning_distances.T, dims=(1, )) + + return winning_distances.T # return with `batch_size` first + + # @torch.jit.script def wtac(distances, labels): winning_indices = torch.min(distances, dim=1).indices