Temporarily remove tangent distance
This commit is contained in:
		@@ -1,15 +1,7 @@
 | 
				
			|||||||
"""ProtoTorch distances"""
 | 
					"""ProtoTorch distances"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# from prototorch.functions.helper import (
 | 
					 | 
				
			||||||
#     _check_shapes,
 | 
					 | 
				
			||||||
#     _int_and_mixed_shape,
 | 
					 | 
				
			||||||
#     equal_int_shape,
 | 
					 | 
				
			||||||
#     get_flat,
 | 
					 | 
				
			||||||
# )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def squared_euclidean_distance(x, y):
 | 
					def squared_euclidean_distance(x, y):
 | 
				
			||||||
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
 | 
					    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
 | 
				
			||||||
@@ -102,160 +94,5 @@ def lomega_distance(x, y, omegas):
 | 
				
			|||||||
    return distances
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
 | 
					 | 
				
			||||||
#     r"""Computes an euclidean distances matrix given two distinct vectors.
 | 
					 | 
				
			||||||
#     last dimension must be the vector dimension!
 | 
					 | 
				
			||||||
#     compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     - ``x.shape = (number_of_x_vectors, vector_dim)``
 | 
					 | 
				
			||||||
#     - ``y.shape = (number_of_y_vectors, vector_dim)``
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
 | 
					 | 
				
			||||||
#     """
 | 
					 | 
				
			||||||
#     for tensor in [x, y]:
 | 
					 | 
				
			||||||
#         if tensor.ndim != 2:
 | 
					 | 
				
			||||||
#             raise ValueError(
 | 
					 | 
				
			||||||
#                 "The tensor dimension must be two. You provide: tensor.ndim=" +
 | 
					 | 
				
			||||||
#                 str(tensor.ndim) + ".")
 | 
					 | 
				
			||||||
#     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
 | 
					 | 
				
			||||||
#         raise ValueError(
 | 
					 | 
				
			||||||
#             "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
 | 
					 | 
				
			||||||
#             + str(tuple(x.shape)[1]) + " and  tuple(y.shape)(y)[1]=" +
 | 
					 | 
				
			||||||
#             str(tuple(y.shape)[1]) + ".")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     y = torch.transpose(y)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
 | 
					 | 
				
			||||||
#             torch.sum(y**2, axis=0, keepdims=True))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     if not squared:
 | 
					 | 
				
			||||||
#         if epsilon == 0:
 | 
					 | 
				
			||||||
#             diss = torch.sqrt(diss)
 | 
					 | 
				
			||||||
#         else:
 | 
					 | 
				
			||||||
#             diss = torch.sqrt(torch.max(diss, epsilon))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     return diss
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 | 
					 | 
				
			||||||
#     r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     For more info about Tangen distances see
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     DOI:10.1109/IJCNN.2016.7727534.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     The subspaces is always assumed as transposed and must be orthogonal!
 | 
					 | 
				
			||||||
#     For local non sparse signals subspaces must be provided!
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
 | 
					 | 
				
			||||||
#     - shape(protos): proto_number x dim1 x dim2 x ... x dimN
 | 
					 | 
				
			||||||
#     - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN)  x prod(projected_atom_shape)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     subspace should be orthogonalized
 | 
					 | 
				
			||||||
#     Pytorch implementation of Sascha Saralajew's tensorflow code.
 | 
					 | 
				
			||||||
#     Translation by Christoph Raab
 | 
					 | 
				
			||||||
#     """
 | 
					 | 
				
			||||||
#     signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
 | 
					 | 
				
			||||||
#     proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
 | 
					 | 
				
			||||||
#     subspace_int_shape = tuple(subspaces.shape)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     # check if the shapes are correct
 | 
					 | 
				
			||||||
#     _check_shapes(signal_int_shape, proto_int_shape)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     atom_axes = list(range(3, len(signal_int_shape)))
 | 
					 | 
				
			||||||
#     # for sparse signals, we use the memory efficient implementation
 | 
					 | 
				
			||||||
#     if signal_int_shape[1] == 1:
 | 
					 | 
				
			||||||
#         signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         if len(atom_axes) > 1:
 | 
					 | 
				
			||||||
#             protos = torch.reshape(protos, [proto_shape[0], -1])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         if subspaces.ndim == 2:
 | 
					 | 
				
			||||||
#             # clean solution without map if the matrix_scope is global
 | 
					 | 
				
			||||||
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
 | 
					 | 
				
			||||||
#                 subspaces, torch.transpose(subspaces))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             projected_signals = torch.dot(signals, projectors)
 | 
					 | 
				
			||||||
#             projected_protos = torch.dot(protos, projectors)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = euclidean_distance_matrix(projected_signals,
 | 
					 | 
				
			||||||
#                                              projected_protos,
 | 
					 | 
				
			||||||
#                                              squared=squared,
 | 
					 | 
				
			||||||
#                                              epsilon=epsilon)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = torch.reshape(
 | 
					 | 
				
			||||||
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             return torch.permute(diss, [0, 2, 1])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         else:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             # no solution without map possible --> memory efficient but slow!
 | 
					 | 
				
			||||||
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
 | 
					 | 
				
			||||||
#                 subspaces,
 | 
					 | 
				
			||||||
#                 subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             projected_protos = (protos @ subspaces
 | 
					 | 
				
			||||||
#                                 ).T  # K.batch_dot(projectors, protos, [1, 1]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             def projected_norm(projector):
 | 
					 | 
				
			||||||
#                 return torch.sum(torch.dot(signals, projector)**2, axis=1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = (torch.transpose(map(projected_norm, projectors)) -
 | 
					 | 
				
			||||||
#                     2 * torch.dot(signals, projected_protos) +
 | 
					 | 
				
			||||||
#                     torch.sum(projected_protos**2, axis=0, keepdims=True))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             if not squared:
 | 
					 | 
				
			||||||
#                 if epsilon == 0:
 | 
					 | 
				
			||||||
#                     diss = torch.sqrt(diss)
 | 
					 | 
				
			||||||
#                 else:
 | 
					 | 
				
			||||||
#                     diss = torch.sqrt(torch.max(diss, epsilon))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = torch.reshape(
 | 
					 | 
				
			||||||
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             return torch.permute(diss, [0, 2, 1])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#     else:
 | 
					 | 
				
			||||||
#         signals = signals.permute([0, 2, 1] + atom_axes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         diff = signals - protos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         # global tangent space
 | 
					 | 
				
			||||||
#         if subspaces.ndim == 2:
 | 
					 | 
				
			||||||
#             # Scope Projectors
 | 
					 | 
				
			||||||
#             projectors = subspaces  #
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             # Scope: Tangentspace Projections
 | 
					 | 
				
			||||||
#             diff = torch.reshape(
 | 
					 | 
				
			||||||
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
 | 
					 | 
				
			||||||
#             projected_diff = diff @ projectors
 | 
					 | 
				
			||||||
#             projected_diff = torch.reshape(
 | 
					 | 
				
			||||||
#                 projected_diff,
 | 
					 | 
				
			||||||
#                 (signal_shape[0], signal_shape[2], signal_shape[1]) +
 | 
					 | 
				
			||||||
#                 signal_shape[3:],
 | 
					 | 
				
			||||||
#             )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = torch.norm(projected_diff, 2, dim=-1)
 | 
					 | 
				
			||||||
#             return diss.permute([0, 2, 1])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#         # local tangent spaces
 | 
					 | 
				
			||||||
#         else:
 | 
					 | 
				
			||||||
#             # Scope: Calculate Projectors
 | 
					 | 
				
			||||||
#             projectors = subspaces
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             # Scope: Tangentspace Projections
 | 
					 | 
				
			||||||
#             diff = torch.reshape(
 | 
					 | 
				
			||||||
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
 | 
					 | 
				
			||||||
#             diff = diff.permute([1, 0, 2])
 | 
					 | 
				
			||||||
#             projected_diff = torch.bmm(diff, projectors)
 | 
					 | 
				
			||||||
#             projected_diff = torch.reshape(
 | 
					 | 
				
			||||||
#                 projected_diff,
 | 
					 | 
				
			||||||
#                 (signal_shape[1], signal_shape[0], signal_shape[2]) +
 | 
					 | 
				
			||||||
#                 signal_shape[3:],
 | 
					 | 
				
			||||||
#             )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#             diss = torch.norm(projected_diff, 2, dim=-1)
 | 
					 | 
				
			||||||
#             return diss.permute([1, 0, 2]).squeeze(-1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Aliases
 | 
					# Aliases
 | 
				
			||||||
sed = squared_euclidean_distance
 | 
					sed = squared_euclidean_distance
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user