Modify stratified_min function

This commit is contained in:
blackfly 2020-04-27 12:48:12 +02:00
parent 532f63b1de
commit d17b9a3346

View File

@ -7,6 +7,9 @@ import torch
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
if distances.size()[1] == nclasses:
# skip if only one prototype per class
return distances
batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float('inf'))