From 7d9dfc27ee2196394d97b0c45e84eb948e9f1de2 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 22 Apr 2021 13:12:19 +0200 Subject: [PATCH] Add similarities file --- prototorch/functions/similarities.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 prototorch/functions/similarities.py diff --git a/prototorch/functions/similarities.py b/prototorch/functions/similarities.py new file mode 100644 index 0000000..cc91c78 --- /dev/null +++ b/prototorch/functions/similarities.py @@ -0,0 +1,18 @@ +"""ProtoTorch similarity functions.""" + +import torch + + +def cosine_similarity(x, y): + """Compute the cosine similarity between :math:`x` and :math:`y`. + + Expected dimension of x is 2. + Expected dimension of y is 2. + """ + norm_x = x.pow(2).sum(1).sqrt() + norm_y = y.pow(2).sum(1).sqrt() + norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T + epsilon = torch.finfo(norm_mat.dtype).eps + norm_mat.clamp_(min=epsilon) + similarities = (x @ y.T) / norm_mat + return similarities