Flatten tensors before computing distances
This commit is contained in:
parent
abe64cfe8f
commit
acd4ac6a86
@ -2,9 +2,8 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||||
equal_int_shape)
|
equal_int_shape, get_flat)
|
||||||
|
|
||||||
|
|
||||||
def squared_euclidean_distance(x, y):
|
def squared_euclidean_distance(x, y):
|
||||||
@ -12,12 +11,10 @@ def squared_euclidean_distance(x, y):
|
|||||||
|
|
||||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
|
|
||||||
**Alias:**
|
**Alias:**
|
||||||
``prototorch.functions.distances.sed``
|
``prototorch.functions.distances.sed``
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
expanded_x = x.unsqueeze(dim=1)
|
expanded_x = x.unsqueeze(dim=1)
|
||||||
batchwise_difference = y - expanded_x
|
batchwise_difference = y - expanded_x
|
||||||
differences_raised = torch.pow(batchwise_difference, 2)
|
differences_raised = torch.pow(batchwise_difference, 2)
|
||||||
@ -30,18 +27,17 @@ def euclidean_distance(x, y):
|
|||||||
|
|
||||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||||
|
|
||||||
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
|
|
||||||
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
|
|
||||||
|
|
||||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat_x_y(x, y)
|
||||||
distances_raised = squared_euclidean_distance(x, y)
|
distances_raised = squared_euclidean_distance(x, y)
|
||||||
distances = torch.sqrt(distances_raised)
|
distances = torch.sqrt(distances_raised)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def euclidean_distance_v2(x, y):
|
def euclidean_distance_v2(x, y):
|
||||||
|
x, y = get_flat(x, y)
|
||||||
diff = y - x.unsqueeze(1)
|
diff = y - x.unsqueeze(1)
|
||||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||||
@ -62,10 +58,9 @@ def lpnorm_distance(x, y, p):
|
|||||||
|
|
||||||
Calls ``torch.cdist``
|
Calls ``torch.cdist``
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param p: p parameter of the lp norm
|
:param p: p parameter of the lp norm
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
distances = torch.cdist(x, y, p=p)
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@ -75,10 +70,9 @@ def omega_distance(x, y, omega):
|
|||||||
|
|
||||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param `torch.tensor` omega: Two dimensional matrix
|
:param `torch.tensor` omega: Two dimensional matrix
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
projected_x = x @ omega
|
projected_x = x @ omega
|
||||||
projected_y = y @ omega
|
projected_y = y @ omega
|
||||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||||
@ -90,10 +84,9 @@ def lomega_distance(x, y, omegas):
|
|||||||
|
|
||||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
projected_x = x @ omegas
|
projected_x = x @ omegas
|
||||||
projected_y = torch.diagonal(y @ omegas).T
|
projected_y = torch.diagonal(y @ omegas).T
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_flat(*args):
|
||||||
|
rv = [x.view(x.size(0), -1) for x in args]
|
||||||
|
return rv
|
||||||
|
|
||||||
|
|
||||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||||
"""Computes the accuracy of a prototype based model.
|
"""Computes the accuracy of a prototype based model.
|
||||||
via Winner-Takes-All rule.
|
via Winner-Takes-All rule.
|
||||||
|
Loading…
Reference in New Issue
Block a user