Add similarities
This commit is contained in:
parent
d26a626677
commit
935d9fa7ad
@ -2,5 +2,8 @@
|
||||
|
||||
from .competitions import *
|
||||
from .components import *
|
||||
from .distances import *
|
||||
from .initializers import *
|
||||
from .losses import *
|
||||
from .pooling import *
|
||||
from .similarities import *
|
||||
|
19
prototorch/core/similarities.py
Normal file
19
prototorch/core/similarities.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""ProtoTorch similarities."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
"""Compute the cosine similarity between :math:`x` and :math:`y`.
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
norm_x = x.pow(2).sum(1).sqrt()
|
||||
norm_y = y.pow(2).sum(1).sqrt()
|
||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||
epsilon = torch.finfo(norm_mat.dtype).eps
|
||||
norm_mat.clamp_(min=epsilon)
|
||||
similarities = (x @ y.T) / norm_mat
|
||||
return similarities
|
Loading…
Reference in New Issue
Block a user