diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py index 961cd69..14c0387 100644 --- a/prototorch/functions/distances.py +++ b/prototorch/functions/distances.py @@ -2,9 +2,8 @@ import numpy as np import torch - from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, - equal_int_shape) + equal_int_shape, get_flat) def squared_euclidean_distance(x, y): @@ -12,12 +11,10 @@ def squared_euclidean_distance(x, y): Compute :math:`{\langle \bm x - \bm y \rangle}_2` - :param `torch.tensor` x: Two dimensional vector - :param `torch.tensor` y: Two dimensional vector - **Alias:** ``prototorch.functions.distances.sed`` """ + x, y = get_flat(x, y) expanded_x = x.unsqueeze(dim=1) batchwise_difference = y - expanded_x differences_raised = torch.pow(batchwise_difference, 2) @@ -30,18 +27,17 @@ def euclidean_distance(x, y): Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}` - :param `torch.tensor` x: Input Tensor of shape :math:`X \times N` - :param `torch.tensor` y: Input Tensor of shape :math:`Y \times N` - :returns: Distance Tensor of shape :math:`X \times Y` :rtype: `torch.tensor` """ + x, y = get_flat_x_y(x, y) distances_raised = squared_euclidean_distance(x, y) distances = torch.sqrt(distances_raised) return distances def euclidean_distance_v2(x, y): + x, y = get_flat(x, y) diff = y - x.unsqueeze(1) pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the @@ -62,10 +58,9 @@ def lpnorm_distance(x, y, p): Calls ``torch.cdist`` - :param `torch.tensor` x: Two dimensional vector - :param `torch.tensor` y: Two dimensional vector :param p: p parameter of the lp norm """ + x, y = get_flat(x, y) distances = torch.cdist(x, y, p=p) return distances @@ -75,10 +70,9 @@ def omega_distance(x, y, omega): Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p` - :param `torch.tensor` x: Two dimensional vector - :param `torch.tensor` y: Two dimensional vector :param `torch.tensor` omega: Two dimensional matrix """ + x, y = get_flat(x, y) projected_x = x @ omega projected_y = y @ omega distances = squared_euclidean_distance(projected_x, projected_y) @@ -90,10 +84,9 @@ def lomega_distance(x, y, omegas): Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p` - :param `torch.tensor` x: Two dimensional vector - :param `torch.tensor` y: Two dimensional vector :param `torch.tensor` omegas: Three dimensional matrix """ + x, y = get_flat(x, y) projected_x = x @ omegas projected_y = torch.diagonal(y @ omegas).T expanded_y = torch.unsqueeze(projected_y, dim=1) diff --git a/prototorch/functions/helper.py b/prototorch/functions/helper.py index b157b78..6797a72 100644 --- a/prototorch/functions/helper.py +++ b/prototorch/functions/helper.py @@ -1,6 +1,11 @@ import torch +def get_flat(*args): + rv = [x.view(x.size(0), -1) for x in args] + return rv + + def calculate_prototype_accuracy(y_pred, y_true, plabels): """Computes the accuracy of a prototype based model. via Winner-Takes-All rule.