Add similarities

This commit is contained in:
Jensun Ravichandran 2021-06-12 20:50:04 +02:00
parent d26a626677
commit 935d9fa7ad
2 changed files with 22 additions and 0 deletions

View File

@ -2,5 +2,8 @@
from .competitions import * from .competitions import *
from .components import * from .components import *
from .distances import *
from .initializers import * from .initializers import *
from .losses import * from .losses import *
from .pooling import *
from .similarities import *

View File

@ -0,0 +1,19 @@
"""ProtoTorch similarities."""
import torch
def cosine_similarity(x, y):
"""Compute the cosine similarity between :math:`x` and :math:`y`.
Expected dimension of x is 2.
Expected dimension of y is 2.
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
norm_x = x.pow(2).sum(1).sqrt()
norm_y = y.pow(2).sum(1).sqrt()
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
epsilon = torch.finfo(norm_mat.dtype).eps
norm_mat.clamp_(min=epsilon)
similarities = (x @ y.T) / norm_mat
return similarities