[BUGFIX] Fix knnc
This commit is contained in:
		@@ -3,7 +3,6 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# @torch.jit.script
 | 
					 | 
				
			||||||
def stratified_min(distances, labels):
 | 
					def stratified_min(distances, labels):
 | 
				
			||||||
    clabels = torch.unique(labels, dim=0)
 | 
					    clabels = torch.unique(labels, dim=0)
 | 
				
			||||||
    nclasses = clabels.size()[0]
 | 
					    nclasses = clabels.size()[0]
 | 
				
			||||||
@@ -31,15 +30,14 @@ def stratified_min(distances, labels):
 | 
				
			|||||||
    return winning_distances.T  # return with `batch_size` first
 | 
					    return winning_distances.T  # return with `batch_size` first
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# @torch.jit.script
 | 
					 | 
				
			||||||
def wtac(distances, labels):
 | 
					def wtac(distances, labels):
 | 
				
			||||||
    winning_indices = torch.min(distances, dim=1).indices
 | 
					    winning_indices = torch.min(distances, dim=1).indices
 | 
				
			||||||
    winning_labels = labels[winning_indices].squeeze()
 | 
					    winning_labels = labels[winning_indices].squeeze()
 | 
				
			||||||
    return winning_labels
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# @torch.jit.script
 | 
					def knnc(distances, labels, k=1):
 | 
				
			||||||
def knnc(distances, labels, k):
 | 
					    winning_indices = torch.topk(-distances, k=k, dim=1).indices
 | 
				
			||||||
    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
 | 
					    winning_labels = torch.mode(labels[winning_indices].squeeze(),
 | 
				
			||||||
    winning_labels = labels[winning_indices].squeeze()
 | 
					                                dim=1).values
 | 
				
			||||||
    return winning_labels
 | 
					    return winning_labels
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user