From d26a626677e6350fbc83f0cff9a821313dd0092d Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:48:39 +0200 Subject: [PATCH] Temporarily remove tangent distance --- prototorch/core/distances.py | 163 ----------------------------------- 1 file changed, 163 deletions(-) diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index 0782769..c19a8dc 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -1,15 +1,7 @@ """ProtoTorch distances""" -import numpy as np import torch -# from prototorch.functions.helper import ( -# _check_shapes, -# _int_and_mixed_shape, -# equal_int_shape, -# get_flat, -# ) - def squared_euclidean_distance(x, y): r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. @@ -102,160 +94,5 @@ def lomega_distance(x, y, omegas): return distances -# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10): -# r"""Computes an euclidean distances matrix given two distinct vectors. -# last dimension must be the vector dimension! -# compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction! - -# - ``x.shape = (number_of_x_vectors, vector_dim)`` -# - ``y.shape = (number_of_y_vectors, vector_dim)`` - -# output: matrix of distances (number_of_x_vectors, number_of_y_vectors) -# """ -# for tensor in [x, y]: -# if tensor.ndim != 2: -# raise ValueError( -# "The tensor dimension must be two. You provide: tensor.ndim=" + -# str(tensor.ndim) + ".") -# if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): -# raise ValueError( -# "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]=" -# + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" + -# str(tuple(y.shape)[1]) + ".") - -# y = torch.transpose(y) - -# diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) + -# torch.sum(y**2, axis=0, keepdims=True)) - -# if not squared: -# if epsilon == 0: -# diss = torch.sqrt(diss) -# else: -# diss = torch.sqrt(torch.max(diss, epsilon)) - -# return diss - -# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): -# r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews - -# For more info about Tangen distances see - -# DOI:10.1109/IJCNN.2016.7727534. - -# The subspaces is always assumed as transposed and must be orthogonal! -# For local non sparse signals subspaces must be provided! - -# - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN -# - shape(protos): proto_number x dim1 x dim2 x ... x dimN -# - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape) - -# subspace should be orthogonalized -# Pytorch implementation of Sascha Saralajew's tensorflow code. -# Translation by Christoph Raab -# """ -# signal_shape, signal_int_shape = _int_and_mixed_shape(signals) -# proto_shape, proto_int_shape = _int_and_mixed_shape(protos) -# subspace_int_shape = tuple(subspaces.shape) - -# # check if the shapes are correct -# _check_shapes(signal_int_shape, proto_int_shape) - -# atom_axes = list(range(3, len(signal_int_shape))) -# # for sparse signals, we use the memory efficient implementation -# if signal_int_shape[1] == 1: -# signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])]) - -# if len(atom_axes) > 1: -# protos = torch.reshape(protos, [proto_shape[0], -1]) - -# if subspaces.ndim == 2: -# # clean solution without map if the matrix_scope is global -# projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( -# subspaces, torch.transpose(subspaces)) - -# projected_signals = torch.dot(signals, projectors) -# projected_protos = torch.dot(protos, projectors) - -# diss = euclidean_distance_matrix(projected_signals, -# projected_protos, -# squared=squared, -# epsilon=epsilon) - -# diss = torch.reshape( -# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - -# return torch.permute(diss, [0, 2, 1]) - -# else: - -# # no solution without map possible --> memory efficient but slow! -# projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( -# subspaces, -# subspaces) # K.batch_dot(subspaces, subspaces, [2, 2]) - -# projected_protos = (protos @ subspaces -# ).T # K.batch_dot(projectors, protos, [1, 1])) - -# def projected_norm(projector): -# return torch.sum(torch.dot(signals, projector)**2, axis=1) - -# diss = (torch.transpose(map(projected_norm, projectors)) - -# 2 * torch.dot(signals, projected_protos) + -# torch.sum(projected_protos**2, axis=0, keepdims=True)) - -# if not squared: -# if epsilon == 0: -# diss = torch.sqrt(diss) -# else: -# diss = torch.sqrt(torch.max(diss, epsilon)) - -# diss = torch.reshape( -# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - -# return torch.permute(diss, [0, 2, 1]) - -# else: -# signals = signals.permute([0, 2, 1] + atom_axes) - -# diff = signals - protos - -# # global tangent space -# if subspaces.ndim == 2: -# # Scope Projectors -# projectors = subspaces # - -# # Scope: Tangentspace Projections -# diff = torch.reshape( -# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) -# projected_diff = diff @ projectors -# projected_diff = torch.reshape( -# projected_diff, -# (signal_shape[0], signal_shape[2], signal_shape[1]) + -# signal_shape[3:], -# ) - -# diss = torch.norm(projected_diff, 2, dim=-1) -# return diss.permute([0, 2, 1]) - -# # local tangent spaces -# else: -# # Scope: Calculate Projectors -# projectors = subspaces - -# # Scope: Tangentspace Projections -# diff = torch.reshape( -# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) -# diff = diff.permute([1, 0, 2]) -# projected_diff = torch.bmm(diff, projectors) -# projected_diff = torch.reshape( -# projected_diff, -# (signal_shape[1], signal_shape[0], signal_shape[2]) + -# signal_shape[3:], -# ) - -# diss = torch.norm(projected_diff, 2, dim=-1) -# return diss.permute([1, 0, 2]).squeeze(-1) - # Aliases sed = squared_euclidean_distance