Add similarities file

This commit is contained in:
Jensun Ravichandran 2021-04-22 13:12:19 +02:00
parent ae75b9ebf7
commit 7d9dfc27ee

View File

@ -0,0 +1,18 @@
"""ProtoTorch similarity functions."""
import torch
def cosine_similarity(x, y):
"""Compute the cosine similarity between :math:`x` and :math:`y`.
Expected dimension of x is 2.
Expected dimension of y is 2.
"""
norm_x = x.pow(2).sum(1).sqrt()
norm_y = y.pow(2).sum(1).sqrt()
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
epsilon = torch.finfo(norm_mat.dtype).eps
norm_mat.clamp_(min=epsilon)
similarities = (x @ y.T) / norm_mat
return similarities