Temporarily remove tangent distance
This commit is contained in:
parent
b4ad870b7c
commit
d26a626677
@@ -1,15 +1,7 @@
"""ProtoTorch distances"""

import numpy as np
import torch

# from prototorch.functions.helper import (
#     _check_shapes,
#     _int_and_mixed_shape,
#     equal_int_shape,
#     get_flat,
# )


def squared_euclidean_distance(x, y):
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
@@ -102,160 +94,5 @@ def lomega_distance(x, y, omegas):
    return distances


# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
#     r"""Compute a Euclidean distance matrix between two tensors of vectors.

#     The last dimension must be the vector dimension!
#     The distances are computed via the dot-product identity, which avoids the
#     memory overhead of an explicit pairwise subtraction.

#     - ``x.shape = (number_of_x_vectors, vector_dim)``
#     - ``y.shape = (number_of_y_vectors, vector_dim)``

#     output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
#     """
#     for tensor in [x, y]:
#         if tensor.ndim != 2:
#             raise ValueError(
#                 "The tensor dimension must be two. You provide: tensor.ndim=" +
#                 str(tensor.ndim) + ".")
#     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
#         raise ValueError(
#             "The vector shape must be equivalent in both tensors. You provide: tuple(x.shape)[1]="
#             + str(tuple(x.shape)[1]) + " and tuple(y.shape)[1]=" +
#             str(tuple(y.shape)[1]) + ".")

#     y = torch.transpose(y)

#     diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
#             torch.sum(y**2, axis=0, keepdims=True))

#     if not squared:
#         if epsilon == 0:
#             diss = torch.sqrt(diss)
#         else:
#             diss = torch.sqrt(torch.max(diss, epsilon))

#     return diss

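# Editor's note, hedged sketch (not part of this commit): the dot-product identity that
# the commented-out euclidean_distance_matrix above relies on is
# ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, evaluated for all pairs at once so the
# (N, M, dim) difference tensor never has to be materialised. The helper name
# _pairwise_squared_distances below is illustrative only and not part of the removed API.
def _pairwise_squared_distances(x, y):
    """x: (N, dim), y: (M, dim) -> (N, M) squared Euclidean distances."""
    x_norms = torch.sum(x**2, dim=1, keepdim=True)  # (N, 1)
    y_norms = torch.sum(y**2, dim=1, keepdim=True).T  # (1, M)
    cross_terms = x @ y.T  # (N, M) pairwise dot products
    # clamp guards against small negative values caused by floating-point cancellation
    return torch.clamp(x_norms - 2 * cross_terms + y_norms, min=0.0)
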
# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
#     r"""Tangent distances based on the TensorFlow implementation by Sascha Saralajew.

#     For more info about tangent distances, see
#     DOI:10.1109/IJCNN.2016.7727534.

#     The subspaces are always assumed to be transposed and must be orthogonal!
#     For local non-sparse signals, subspaces must be provided!

#     - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
#     - shape(protos): proto_number x dim1 x dim2 x ... x dimN
#     - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)

#     The subspaces should be orthogonalized.
#     PyTorch translation of Sascha Saralajew's TensorFlow code by Christoph Raab.
#     """
#     signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
#     proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
#     subspace_int_shape = tuple(subspaces.shape)

#     # check if the shapes are correct
#     _check_shapes(signal_int_shape, proto_int_shape)

#     atom_axes = list(range(3, len(signal_int_shape)))
#     # for sparse signals, we use the memory efficient implementation
#     if signal_int_shape[1] == 1:
#         signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])

#         if len(atom_axes) > 1:
#             protos = torch.reshape(protos, [proto_shape[0], -1])

#         if subspaces.ndim == 2:
#             # clean solution without map if the matrix_scope is global
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
#                 subspaces, torch.transpose(subspaces))

#             projected_signals = torch.dot(signals, projectors)
#             projected_protos = torch.dot(protos, projectors)

#             diss = euclidean_distance_matrix(projected_signals,
#                                              projected_protos,
#                                              squared=squared,
#                                              epsilon=epsilon)

#             diss = torch.reshape(
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

#             return torch.permute(diss, [0, 2, 1])

#         else:

#             # no solution without map possible --> memory efficient but slow!
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
#                 subspaces,
#                 subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])

#             projected_protos = (protos @ subspaces
#                                 ).T  # K.batch_dot(projectors, protos, [1, 1]))

#             def projected_norm(projector):
#                 return torch.sum(torch.dot(signals, projector)**2, axis=1)

#             diss = (torch.transpose(map(projected_norm, projectors)) -
#                     2 * torch.dot(signals, projected_protos) +
#                     torch.sum(projected_protos**2, axis=0, keepdims=True))

#             if not squared:
#                 if epsilon == 0:
#                     diss = torch.sqrt(diss)
#                 else:
#                     diss = torch.sqrt(torch.max(diss, epsilon))

#             diss = torch.reshape(
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

#             return torch.permute(diss, [0, 2, 1])

#     else:
#         signals = signals.permute([0, 2, 1] + atom_axes)

#         diff = signals - protos

#         # global tangent space
#         if subspaces.ndim == 2:
#             # Scope: Projectors
#             projectors = subspaces

#             # Scope: Tangentspace Projections
#             diff = torch.reshape(
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
#             projected_diff = diff @ projectors
#             projected_diff = torch.reshape(
#                 projected_diff,
#                 (signal_shape[0], signal_shape[2], signal_shape[1]) +
#                 signal_shape[3:],
#             )

#             diss = torch.norm(projected_diff, 2, dim=-1)
#             return diss.permute([0, 2, 1])

#         # local tangent spaces
#         else:
#             # Scope: Calculate Projectors
#             projectors = subspaces

#             # Scope: Tangentspace Projections
#             diff = torch.reshape(
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
#             diff = diff.permute([1, 0, 2])
#             projected_diff = torch.bmm(diff, projectors)
#             projected_diff = torch.reshape(
#                 projected_diff,
#                 (signal_shape[1], signal_shape[0], signal_shape[2]) +
#                 signal_shape[3:],
#             )

#             diss = torch.norm(projected_diff, 2, dim=-1)
#             return diss.permute([1, 0, 2]).squeeze(-1)

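# Editor's note, hedged sketch (not part of this commit): the global-subspace case of the
# removed tangent_distance measures distances after projecting difference vectors onto the
# orthogonal complement of a single orthonormal tangent basis, mirroring the
# eye(d) - W @ W.T projector used in the sparse branch above. The helper name
# _global_tangent_distance and the flat (n, dim)/(m, dim) shapes are simplifying
# assumptions, not the removed API.
def _global_tangent_distance(x, protos, subspace):
    """x: (n, dim), protos: (m, dim), subspace: (dim, q) with orthonormal columns."""
    projector = torch.eye(subspace.shape[0]) - subspace @ subspace.T  # (dim, dim)
    diff = x.unsqueeze(1) - protos.unsqueeze(0)  # (n, m, dim) pairwise differences
    projected = diff @ projector  # drop the component lying inside the tangent subspace
    return torch.norm(projected, p=2, dim=-1)  # (n, m) tangent distances
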
# Aliases
sed = squared_euclidean_distance
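
# Editor's note, hedged usage sketch (not part of the commit): exercising the `sed` alias
# defined above. The (4, 3) and (2, 3) batch shapes and the (4, 2) result shape are
# assumptions based on the docstring of squared_euclidean_distance, not verified here.
if __name__ == "__main__":
    x = torch.randn(4, 3)  # four data points in R^3
    y = torch.randn(2, 3)  # two prototypes in R^3
    distances = sed(x, y)  # pairwise squared Euclidean distances
    print(distances.shape)  # expected: torch.Size([4, 2])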