Add similarities
This commit is contained in:
		@@ -2,5 +2,8 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from .competitions import *
 | 
					from .competitions import *
 | 
				
			||||||
from .components import *
 | 
					from .components import *
 | 
				
			||||||
 | 
					from .distances import *
 | 
				
			||||||
from .initializers import *
 | 
					from .initializers import *
 | 
				
			||||||
from .losses import *
 | 
					from .losses import *
 | 
				
			||||||
 | 
					from .pooling import *
 | 
				
			||||||
 | 
					from .similarities import *
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										19
									
								
								prototorch/core/similarities.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								prototorch/core/similarities.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch similarities."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cosine_similarity(x, y):
 | 
				
			||||||
 | 
					    """Compute the cosine similarity between :math:`x` and :math:`y`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Expected dimension of x is 2.
 | 
				
			||||||
 | 
					    Expected dimension of y is 2.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
 | 
				
			||||||
 | 
					    norm_x = x.pow(2).sum(1).sqrt()
 | 
				
			||||||
 | 
					    norm_y = y.pow(2).sum(1).sqrt()
 | 
				
			||||||
 | 
					    norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
 | 
				
			||||||
 | 
					    epsilon = torch.finfo(norm_mat.dtype).eps
 | 
				
			||||||
 | 
					    norm_mat.clamp_(min=epsilon)
 | 
				
			||||||
 | 
					    similarities = (x @ y.T) / norm_mat
 | 
				
			||||||
 | 
					    return similarities
 | 
				
			||||||
		Reference in New Issue
	
	Block a user