Add euclidean_distance_v2

parent 7d9dfc27ee
commit e2918dffed
@@ -43,9 +43,21 @@ def euclidean_distance(x, y):
     return distances
 
 
+def euclidean_distance_v2(x, y):
+    diff = y - x.unsqueeze(1)
+    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
+    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
+    # batch diagonal. See:
+    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
+    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
+    # print(f"{diff.shape=}")  # (nx, ny, ndim)
+    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
+    # print(f"{distances.shape=}")  # (nx, ny)
+    return distances
+
+
 def lpnorm_distance(x, y, p):
-    r"""
-    Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
+    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
     Also known as Minkowski distance.
 
     Compute :math:`{\| \bm x - \bm y \|}_p.`
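The new function broadcasts `y - x.unsqueeze(1)` into an `(nx, ny, ndim)` difference tensor, forms the `(nx, ny, ny)` Gram matrix of those differences, and keeps only its batch diagonal. A minimal smoke test, assuming the function above is in scope and using `torch.cdist` as the reference (shapes here are illustrative):

    import torch

    x = torch.randn(5, 2)  # (nx, ndim)
    y = torch.randn(3, 2)  # (ny, ndim)

    d = euclidean_distance_v2(x, y)  # (nx, ny)
    assert torch.allclose(d, torch.cdist(x, y), atol=1e-6)

Note that the off-diagonal Gram entries can be negative, so the elementwise `sqrt()` produces NaNs that the `diagonal()` call then discards; `diff.pow(2).sum(-1).sqrt()` would give the same `(nx, ny)` result without the `(nx, ny, ny)` intermediate.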
@@ -88,7 +100,7 @@ def lomega_distance(x, y, omegas):
     projected_y = torch.diagonal(y @ omegas).T
     expanded_y = torch.unsqueeze(projected_y, dim=1)
     batchwise_difference = expanded_y - projected_x
-    differences_squared = batchwise_difference ** 2
+    differences_squared = batchwise_difference**2
     distances = torch.sum(differences_squared, dim=2)
     distances = distances.permute(1, 0)
     return distances
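For readers tracing the broadcasting in this hunk, a standalone sketch with hypothetical shapes (`projected_x` is computed from the omega transforms a few lines above this hunk, outside the diff context):

    import torch

    num_protos, num_x, latent = 3, 4, 2  # hypothetical sizes
    projected_x = torch.randn(num_protos, num_x, latent)
    projected_y = torch.randn(num_protos, latent)

    expanded_y = torch.unsqueeze(projected_y, dim=1)   # (num_protos, 1, latent)
    batchwise_difference = expanded_y - projected_x    # (num_protos, num_x, latent)
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)  # (num_protos, num_x)
    distances = distances.permute(1, 0)                # (num_x, num_protos)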
@@ -107,26 +119,18 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
     for tensor in [x, y]:
         if tensor.ndim != 2:
             raise ValueError(
-                "The tensor dimension must be two. You provide: tensor.ndim="
-                + str(tensor.ndim)
-                + "."
-            )
+                "The tensor dimension must be two. You provide: tensor.ndim=" +
+                str(tensor.ndim) + ".")
     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
         raise ValueError(
             "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
-            + str(tuple(x.shape)[1])
-            + " and tuple(y.shape)(y)[1]="
-            + str(tuple(y.shape)[1])
-            + "."
-        )
+            + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
+            str(tuple(y.shape)[1]) + ".")
 
     y = torch.transpose(y)
 
-    diss = (
-        torch.sum(x ** 2, axis=1, keepdims=True)
-        - 2 * torch.dot(x, y)
-        + torch.sum(y ** 2, axis=0, keepdims=True)
-    )
+    diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
+            torch.sum(y**2, axis=0, keepdims=True))
 
     if not squared:
         if epsilon == 0:
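The `diss` expression is the standard expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2. Both sides of this hunk appear to be carried over from a Keras port: in PyTorch, `torch.dot` only accepts 1-D tensors and `torch.transpose` requires explicit dims. A runnable sketch of the same expansion, with `@` and `.T` standing in for those calls (shapes illustrative):

    import torch

    x = torch.randn(5, 2)                               # (nx, dim)
    y = torch.randn(3, 2)                               # (ny, dim)

    yT = y.T                                            # (dim, ny)
    diss = (torch.sum(x**2, axis=1, keepdims=True)      # (nx, 1)
            - 2 * (x @ yT)                              # (nx, ny)
            + torch.sum(yT**2, axis=0, keepdims=True))  # (1, ny)

    assert torch.allclose(diss.clamp_min(0).sqrt(), torch.cdist(x, y), atol=1e-5)

The `clamp_min(0)` here plays the same role as the epsilon clamp in the `if not squared:` branch below: the expansion can dip slightly negative in floating point.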
@@ -173,19 +177,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
     if subspaces.ndim == 2:
         # clean solution without map if the matrix_scope is global
         projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
-            subspaces, torch.transpose(subspaces)
-        )
+            subspaces, torch.transpose(subspaces))
 
         projected_signals = torch.dot(signals, projectors)
         projected_protos = torch.dot(protos, projectors)
 
-        diss = euclidean_distance_matrix(
-            projected_signals, projected_protos, squared=squared, epsilon=epsilon
-        )
+        diss = euclidean_distance_matrix(projected_signals,
+                                         projected_protos,
+                                         squared=squared,
+                                         epsilon=epsilon)
 
         diss = torch.reshape(
-            diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
-        )
+            diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 
         return torch.permute(diss, [0, 2, 1])
 
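The global branch builds the complement projector P = I - U Uᵀ once and reuses it for all signals and prototypes. A minimal sketch of that projector, assuming `subspaces` holds an orthonormal basis in its columns (and with `@` in place of `torch.dot`, which is 1-D-only in PyTorch):

    import torch

    dim, subspace_dim = 4, 2  # hypothetical sizes
    subspaces = torch.linalg.qr(torch.randn(dim, subspace_dim))[0]
    projectors = torch.eye(dim) - subspaces @ subspaces.T

    # An orthogonal projector is idempotent: P @ P == P.
    assert torch.allclose(projectors @ projectors, projectors, atol=1e-6)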
@@ -193,21 +196,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
         # no solution without map possible --> memory efficient but slow!
         projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
-            subspaces, subspaces
-        )  # K.batch_dot(subspaces, subspaces, [2, 2])
+            subspaces,
+            subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])
 
-        projected_protos = (
-            protos @ subspaces
-        ).T  # K.batch_dot(projectors, protos, [1, 1]))
+        projected_protos = (protos @ subspaces
+                            ).T  # K.batch_dot(projectors, protos, [1, 1]))
 
         def projected_norm(projector):
-            return torch.sum(torch.dot(signals, projector) ** 2, axis=1)
+            return torch.sum(torch.dot(signals, projector)**2, axis=1)
 
-        diss = (
-            torch.transpose(map(projected_norm, projectors))
-            - 2 * torch.dot(signals, projected_protos)
-            + torch.sum(projected_protos ** 2, axis=0, keepdims=True)
-        )
+        diss = (torch.transpose(map(projected_norm, projectors)) -
+                2 * torch.dot(signals, projected_protos) +
+                torch.sum(projected_protos**2, axis=0, keepdims=True))
 
         if not squared:
             if epsilon == 0:
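In this batched branch the `K.batch_dot(subspaces, subspaces, [2, 2])` comment gives the intended semantics: contract the last axis of both operands, i.e. a per-prototype U Uᵀ. A direct `torch.bmm(subspaces, subspaces)` only type-checks for square matrices, so a sketch of the intended computation (assuming `subspaces` is `(num_protos, dim, subspace_dim)` with orthonormal columns) transposes the second operand:

    import torch

    num_protos, dim, subspace_dim = 3, 4, 2  # hypothetical sizes
    subspaces = torch.linalg.qr(torch.randn(num_protos, dim, subspace_dim))[0]
    projectors = torch.eye(dim) - torch.bmm(subspaces,
                                            subspaces.transpose(1, 2))
    # projectors: (num_protos, dim, dim), one complement projector per prototype.

Relatedly, `map(projected_norm, projectors)` yields a Python iterator rather than a tensor, so `torch.transpose` over it would need something like `torch.stack([...])` first; the lines here seem to preserve the Keras-era pseudocode.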
@@ -216,8 +216,7 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
             diss = torch.sqrt(torch.max(diss, epsilon))
 
         diss = torch.reshape(
-            diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
-        )
+            diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 
         return torch.permute(diss, [0, 2, 1])
 
@@ -233,12 +232,12 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
             # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
-            )
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
             projected_diff = diff @ projectors
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
+                (signal_shape[0], signal_shape[2], signal_shape[1]) +
+                signal_shape[3:],
             )
 
             diss = torch.norm(projected_diff, 2, dim=-1)
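Both tangent-space branches (this hunk and the next) use the same fold, project, unfold pattern: collapse the leading axes so one matrix product applies the projector to every signal, then restore the original layout before taking the norm. A generic trace with hypothetical sizes:

    import torch

    b, n, d = 2, 3, 4                       # batch, prototypes, feature dim
    diff = torch.randn(b, n, d)
    projector = torch.eye(d)                # stand-in for `projectors`

    flat = torch.reshape(diff, (b * n, d))  # fold leading axes
    projected = torch.reshape(flat @ projector, (b, n, d))  # project, unfold

    assert torch.allclose(projected, diff)  # identity projector is a no-op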
@@ -251,13 +250,13 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
             # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
-            )
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
             diff = diff.permute([1, 0, 2])
             projected_diff = torch.bmm(diff, projectors)
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
+                (signal_shape[1], signal_shape[0], signal_shape[2]) +
+                signal_shape[3:],
             )
 
             diss = torch.norm(projected_diff, 2, dim=-1)