diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py index 44c8248..bf0e988 100644 --- a/prototorch/functions/competitions.py +++ b/prototorch/functions/competitions.py @@ -7,6 +7,9 @@ import torch def stratified_min(distances, labels): clabels = torch.unique(labels, dim=0) nclasses = clabels.size()[0] + if distances.size()[1] == nclasses: + # skip if only one prototype per class + return distances batch_size = distances.size()[0] winning_distances = torch.zeros(nclasses, batch_size) inf = torch.full_like(distances.T, fill_value=float('inf'))