Add euclidean_distance_v2

This commit is contained in:
Jensun Ravichandran 2021-04-22 16:55:50 +02:00
parent 7d9dfc27ee
commit e2918dffed

View File

@ -43,9 +43,21 @@ def euclidean_distance(x, y):
return distances return distances
def euclidean_distance_v2(x, y):
diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
# batch diagonal. See:
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
# print(f"{diff.shape=}") # (nx, ny, ndim)
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
# print(f"{distances.shape=}") # (nx, ny)
return distances
def lpnorm_distance(x, y, p): def lpnorm_distance(x, y, p):
r""" r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
Also known as Minkowski distance. Also known as Minkowski distance.
Compute :math:`{\| \bm x - \bm y \|}_p`. Compute :math:`{\| \bm x - \bm y \|}_p`.
@ -107,26 +119,18 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
for tensor in [x, y]: for tensor in [x, y]:
if tensor.ndim != 2: if tensor.ndim != 2:
raise ValueError( raise ValueError(
"The tensor dimension must be two. You provide: tensor.ndim=" "The tensor dimension must be two. You provide: tensor.ndim=" +
+ str(tensor.ndim) str(tensor.ndim) + ".")
+ "."
)
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
raise ValueError( raise ValueError(
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]=" "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
+ str(tuple(x.shape)[1]) + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
+ " and tuple(y.shape)(y)[1]=" str(tuple(y.shape)[1]) + ".")
+ str(tuple(y.shape)[1])
+ "."
)
y = torch.transpose(y) y = torch.transpose(y)
diss = ( diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
torch.sum(x ** 2, axis=1, keepdims=True) torch.sum(y**2, axis=0, keepdims=True))
- 2 * torch.dot(x, y)
+ torch.sum(y ** 2, axis=0, keepdims=True)
)
if not squared: if not squared:
if epsilon == 0: if epsilon == 0:
@ -173,19 +177,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
if subspaces.ndim == 2: if subspaces.ndim == 2:
# clean solution without map if the matrix_scope is global # clean solution without map if the matrix_scope is global
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
subspaces, torch.transpose(subspaces) subspaces, torch.transpose(subspaces))
)
projected_signals = torch.dot(signals, projectors) projected_signals = torch.dot(signals, projectors)
projected_protos = torch.dot(protos, projectors) projected_protos = torch.dot(protos, projectors)
diss = euclidean_distance_matrix( diss = euclidean_distance_matrix(projected_signals,
projected_signals, projected_protos, squared=squared, epsilon=epsilon projected_protos,
) squared=squared,
epsilon=epsilon)
diss = torch.reshape( diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]] diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
)
return torch.permute(diss, [0, 2, 1]) return torch.permute(diss, [0, 2, 1])
@ -193,21 +196,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
# no solution without map possible --> memory efficient but slow! # no solution without map possible --> memory efficient but slow!
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
subspaces, subspaces subspaces,
) # K.batch_dot(subspaces, subspaces, [2, 2]) subspaces) # K.batch_dot(subspaces, subspaces, [2, 2])
projected_protos = ( projected_protos = (protos @ subspaces
protos @ subspaces
).T # K.batch_dot(projectors, protos, [1, 1])) ).T # K.batch_dot(projectors, protos, [1, 1]))
def projected_norm(projector): def projected_norm(projector):
return torch.sum(torch.dot(signals, projector)**2, axis=1) return torch.sum(torch.dot(signals, projector)**2, axis=1)
diss = ( diss = (torch.transpose(map(projected_norm, projectors)) -
torch.transpose(map(projected_norm, projectors)) 2 * torch.dot(signals, projected_protos) +
- 2 * torch.dot(signals, projected_protos) torch.sum(projected_protos**2, axis=0, keepdims=True))
+ torch.sum(projected_protos ** 2, axis=0, keepdims=True)
)
if not squared: if not squared:
if epsilon == 0: if epsilon == 0:
@ -216,8 +216,7 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
diss = torch.sqrt(torch.max(diss, epsilon)) diss = torch.sqrt(torch.max(diss, epsilon))
diss = torch.reshape( diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]] diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
)
return torch.permute(diss, [0, 2, 1]) return torch.permute(diss, [0, 2, 1])
@ -233,12 +232,12 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
# Scope: Tangentspace Projections # Scope: Tangentspace Projections
diff = torch.reshape( diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1) diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
)
projected_diff = diff @ projectors projected_diff = diff @ projectors
projected_diff = torch.reshape( projected_diff = torch.reshape(
projected_diff, projected_diff,
(signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:], (signal_shape[0], signal_shape[2], signal_shape[1]) +
signal_shape[3:],
) )
diss = torch.norm(projected_diff, 2, dim=-1) diss = torch.norm(projected_diff, 2, dim=-1)
@ -251,13 +250,13 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
# Scope: Tangentspace Projections # Scope: Tangentspace Projections
diff = torch.reshape( diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1) diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
)
diff = diff.permute([1, 0, 2]) diff = diff.permute([1, 0, 2])
projected_diff = torch.bmm(diff, projectors) projected_diff = torch.bmm(diff, projectors)
projected_diff = torch.reshape( projected_diff = torch.reshape(
projected_diff, projected_diff,
(signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:], (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:],
) )
diss = torch.norm(projected_diff, 2, dim=-1) diss = torch.norm(projected_diff, 2, dim=-1)