Modify stratified_min function
This commit is contained in:
parent
532f63b1de
commit
d17b9a3346
@ -7,6 +7,9 @@ import torch
|
||||
def stratified_min(distances, labels):
|
||||
clabels = torch.unique(labels, dim=0)
|
||||
nclasses = clabels.size()[0]
|
||||
if distances.size()[1] == nclasses:
|
||||
# skip if only one prototype per class
|
||||
return distances
|
||||
batch_size = distances.size()[0]
|
||||
winning_distances = torch.zeros(nclasses, batch_size)
|
||||
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
||||
|
Loading…
Reference in New Issue
Block a user