Modify stratified_min function
This commit is contained in:
parent
532f63b1de
commit
d17b9a3346
@ -7,6 +7,9 @@ import torch
|
|||||||
def stratified_min(distances, labels):
|
def stratified_min(distances, labels):
|
||||||
clabels = torch.unique(labels, dim=0)
|
clabels = torch.unique(labels, dim=0)
|
||||||
nclasses = clabels.size()[0]
|
nclasses = clabels.size()[0]
|
||||||
|
if distances.size()[1] == nclasses:
|
||||||
|
# skip if only one prototype per class
|
||||||
|
return distances
|
||||||
batch_size = distances.size()[0]
|
batch_size = distances.size()[0]
|
||||||
winning_distances = torch.zeros(nclasses, batch_size)
|
winning_distances = torch.zeros(nclasses, batch_size)
|
||||||
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
||||||
|
Loading…
Reference in New Issue
Block a user