Flatten tensors before computing distances
This commit is contained in:
parent
abe64cfe8f
commit
acd4ac6a86
@ -2,9 +2,8 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||
equal_int_shape)
|
||||
equal_int_shape, get_flat)
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
@ -12,12 +11,10 @@ def squared_euclidean_distance(x, y):
|
||||
|
||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||
|
||||
:param `torch.tensor` x: Two dimensional vector
|
||||
:param `torch.tensor` y: Two dimensional vector
|
||||
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
@ -30,18 +27,17 @@ def euclidean_distance(x, y):
|
||||
|
||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||
|
||||
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
|
||||
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
|
||||
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = get_flat_x_y(x, y)
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_v2(x, y):
|
||||
x, y = get_flat(x, y)
|
||||
diff = y - x.unsqueeze(1)
|
||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||
@ -62,10 +58,9 @@ def lpnorm_distance(x, y, p):
|
||||
|
||||
Calls ``torch.cdist``
|
||||
|
||||
:param `torch.tensor` x: Two dimensional vector
|
||||
:param `torch.tensor` y: Two dimensional vector
|
||||
:param p: p parameter of the lp norm
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
distances = torch.cdist(x, y, p=p)
|
||||
return distances
|
||||
|
||||
@ -75,10 +70,9 @@ def omega_distance(x, y, omega):
|
||||
|
||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||
|
||||
:param `torch.tensor` x: Two dimensional vector
|
||||
:param `torch.tensor` y: Two dimensional vector
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
@ -90,10 +84,9 @@ def lomega_distance(x, y, omegas):
|
||||
|
||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||
|
||||
:param `torch.tensor` x: Two dimensional vector
|
||||
:param `torch.tensor` y: Two dimensional vector
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
|
@ -1,6 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_flat(*args):
|
||||
rv = [x.view(x.size(0), -1) for x in args]
|
||||
return rv
|
||||
|
||||
|
||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||
"""Computes the accuracy of a prototype based model.
|
||||
via Winner-Takes-All rule.
|
||||
|
Loading…
Reference in New Issue
Block a user