Remove prototorch/functions and prototorch/modules
parent 38244f6691
commit b4ad870b7c
@@ -1,5 +0,0 @@
"""ProtoTorch functions."""

from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac
from .pooling import *

@@ -1,62 +0,0 @@
"""ProtoTorch activation functions."""

import torch

ACTIVATIONS = dict()


def register_activation(fn):
    """Add the activation function to the registry."""
    name = fn.__name__
    ACTIVATIONS[name] = fn
    return fn


@register_activation
def identity(x, beta=0.0):
    """Identity activation function.

    Definition:
    :math:`f(x) = x`

    Keyword Arguments:
    beta (`float`): Ignored.
    """
    return x


@register_activation
def sigmoid_beta(x, beta=10.0):
    r"""Sigmoid activation function with scaling.

    Definition:
    :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`

    Keyword Arguments:
    beta (`float`): Scaling parameter :math:`\beta`
    """
    out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
    return out


@register_activation
def swish_beta(x, beta=10.0):
    r"""Swish activation function with scaling.

    Definition:
    :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`

    Keyword Arguments:
    beta (`float`): Scaling parameter :math:`\beta`
    """
    out = x * sigmoid_beta(x, beta=beta)
    return out


def get_activation(funcname):
    """Deserialize the activation function."""
    if callable(funcname):
        return funcname
    if funcname in ACTIVATIONS:
        return ACTIVATIONS.get(funcname)
    raise NameError(f"Activation {funcname} was not found.")
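Before this commit, the registered activations were typically looked up by name and applied element-wise. A minimal usage sketch, assuming the now-removed import path `prototorch.functions.activations`:

    import torch
    from prototorch.functions.activations import get_activation

    act = get_activation("sigmoid_beta")  # also accepts the function object itself
    x = torch.linspace(-1.0, 1.0, steps=5)
    y = act(x, beta=10.0)  # element-wise scaled sigmoid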

@@ -1,28 +0,0 @@
"""ProtoTorch competition functions."""

import torch


def wtac(distances: torch.Tensor,
         labels: torch.LongTensor) -> (torch.LongTensor):
    """Winner-Takes-All-Competition.

    Returns the labels corresponding to the winners.
    """
    winning_indices = torch.min(distances, dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
    return winning_labels


def knnc(distances: torch.Tensor,
         labels: torch.LongTensor,
         k: int = 1) -> (torch.LongTensor):
    """K-Nearest-Neighbors-Competition.

    Returns the labels corresponding to the winners.
    """
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels
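A short sketch of how these competition functions were called, assuming the now-removed import path `prototorch.functions.competitions`:

    import torch
    from prototorch.functions.competitions import knnc, wtac

    distances = torch.rand(4, 3)  # 4 samples, 3 prototypes
    plabels = torch.LongTensor([0, 1, 2])  # one label per prototype
    winners = wtac(distances, plabels)  # label of the closest prototype per sample
    knn_winners = knnc(distances, plabels, k=3)  # majority label among the 3 nearest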

@@ -1,258 +0,0 @@
"""ProtoTorch distance functions."""

import numpy as np
import torch

from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
                                         equal_int_shape, get_flat)


def squared_euclidean_distance(x, y):
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.

    Compute :math:`{\langle \bm x - \bm y \rangle}_2`

    **Alias:**
    ``prototorch.functions.distances.sed``
    """
    x, y = get_flat(x, y)
    expanded_x = x.unsqueeze(dim=1)
    batchwise_difference = y - expanded_x
    differences_raised = torch.pow(batchwise_difference, 2)
    distances = torch.sum(differences_raised, axis=2)
    return distances


def euclidean_distance(x, y):
    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.

    Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = get_flat(x, y)
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    x, y = get_flat(x, y)
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
    # batch diagonal. See:
    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
    # print(f"{diff.shape=}")  # (nx, ny, ndim)
    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
    # print(f"{distances.shape=}")  # (nx, ny)
    return distances


def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
    Also known as Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``

    :param p: p parameter of the lp norm
    """
    x, y = get_flat(x, y)
    distances = torch.cdist(x, y, p=p)
    return distances


def omega_distance(x, y, omega):
    r"""Omega distance.

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
    r"""Localized Omega distance.

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
    batchwise_difference = expanded_y - projected_x
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)
    distances = distances.permute(1, 0)
    return distances


def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
    r"""Compute a Euclidean distance matrix between two distinct sets of vectors.

    The last dimension must be the vector dimension. The distance is computed
    via the dot-product identity, which avoids the memory overhead of the
    pairwise subtraction.

    - ``x.shape = (number_of_x_vectors, vector_dim)``
    - ``y.shape = (number_of_y_vectors, vector_dim)``

    Output: matrix of distances of shape (number_of_x_vectors, number_of_y_vectors).
    """
    for tensor in [x, y]:
        if tensor.ndim != 2:
            raise ValueError(
                "The tensor dimension must be two. You provide: tensor.ndim=" +
                str(tensor.ndim) + ".")
    if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
        raise ValueError(
            "The vector shape must be equivalent in both tensors. You provide: tuple(x.shape)[1]="
            + str(tuple(x.shape)[1]) + " and tuple(y.shape)[1]=" +
            str(tuple(y.shape)[1]) + ".")

    y = torch.transpose(y)

    diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
            torch.sum(y**2, axis=0, keepdims=True))

    if not squared:
        if epsilon == 0:
            diss = torch.sqrt(diss)
        else:
            diss = torch.sqrt(torch.max(diss, epsilon))

    return diss


def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
    r"""Tangent distances based on the TensorFlow implementation by Sascha Saralajew.

    For more information about tangent distances, see
    DOI:10.1109/IJCNN.2016.7727534.

    The subspaces are always assumed to be transposed and must be orthogonal!
    For local non-sparse signals, subspaces must be provided!

    - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
    - shape(protos): proto_number x dim1 x dim2 x ... x dimN
    - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)

    The subspaces should be orthogonalized.
    PyTorch implementation of Sascha Saralajew's TensorFlow code.
    Translation by Christoph Raab.
    """
    signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
    proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
    subspace_int_shape = tuple(subspaces.shape)

    # check if the shapes are correct
    _check_shapes(signal_int_shape, proto_int_shape)

    atom_axes = list(range(3, len(signal_int_shape)))
    # for sparse signals, we use the memory efficient implementation
    if signal_int_shape[1] == 1:
        signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])

        if len(atom_axes) > 1:
            protos = torch.reshape(protos, [proto_shape[0], -1])

        if subspaces.ndim == 2:
            # clean solution without map if the matrix_scope is global
            projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
                subspaces, torch.transpose(subspaces))

            projected_signals = torch.dot(signals, projectors)
            projected_protos = torch.dot(protos, projectors)

            diss = euclidean_distance_matrix(projected_signals,
                                             projected_protos,
                                             squared=squared,
                                             epsilon=epsilon)

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

        else:
            # no solution without map possible --> memory efficient but slow!
            projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
                subspaces,
                subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])

            projected_protos = (protos @ subspaces
                                ).T  # K.batch_dot(projectors, protos, [1, 1]))

            def projected_norm(projector):
                return torch.sum(torch.dot(signals, projector)**2, axis=1)

            diss = (torch.transpose(map(projected_norm, projectors)) -
                    2 * torch.dot(signals, projected_protos) +
                    torch.sum(projected_protos**2, axis=0, keepdims=True))

            if not squared:
                if epsilon == 0:
                    diss = torch.sqrt(diss)
                else:
                    diss = torch.sqrt(torch.max(diss, epsilon))

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

    else:
        signals = signals.permute([0, 2, 1] + atom_axes)

        diff = signals - protos

        # global tangent space
        if subspaces.ndim == 2:
            # Scope Projectors
            projectors = subspaces

            # Scope: Tangentspace Projections
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            projected_diff = diff @ projectors
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[0], signal_shape[2], signal_shape[1]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([0, 2, 1])

        # local tangent spaces
        else:
            # Scope: Calculate Projectors
            projectors = subspaces

            # Scope: Tangentspace Projections
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            diff = diff.permute([1, 0, 2])
            projected_diff = torch.bmm(diff, projectors)
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[1], signal_shape[0], signal_shape[2]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([1, 0, 2]).squeeze(-1)


# Aliases
sed = squared_euclidean_distance
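For reference, a minimal sketch of the pairwise-distance API, assuming the now-removed import path `prototorch.functions.distances` (the path is also referenced in the alias docstring above):

    import torch
    from prototorch.functions.distances import omega_distance, squared_euclidean_distance

    x = torch.randn(8, 2)  # 8 data points
    y = torch.randn(3, 2)  # 3 prototypes
    d = squared_euclidean_distance(x, y)  # shape (8, 3)

    omega = torch.eye(2)  # identity projection recovers the plain squared distance
    d_omega = omega_distance(x, y, omega)  # shape (8, 3)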

@@ -1,94 +0,0 @@
import torch


def get_flat(*args):
    rv = [x.view(x.size(0), -1) for x in args]
    return rv


def calculate_prototype_accuracy(y_pred, y_true, plabels):
    """Compute the accuracy of a prototype-based model via the Winner-Takes-All rule.

    Requirement:
    y_pred.shape == y_true.shape
    unique(y_pred) in plabels
    """
    with torch.no_grad():
        idx = torch.argmin(y_pred, axis=1)
        return torch.true_divide(torch.sum(y_true == plabels[idx]),
                                 len(y_pred)) * 100


def predict_label(y_pred, plabels):
    r"""Predict labels given the predictions of a prototype-based model."""
    with torch.no_grad():
        return plabels[torch.argmin(y_pred, 1)]


def mixed_shape(inputs):
    if not torch.is_tensor(inputs):
        raise ValueError("Input must be a tensor.")
    else:
        int_shape = list(inputs.shape)
        # sometimes int_shape returns mixed integer types
        int_shape = [int(i) if i is not None else i for i in int_shape]
        tensor_shape = inputs.shape

        for i, s in enumerate(int_shape):
            if s is None:
                int_shape[i] = tensor_shape[i]
        return tuple(int_shape)


def equal_int_shape(shape_1, shape_2):
    if not isinstance(shape_1,
                      (tuple, list)) or not isinstance(shape_2, (tuple, list)):
        raise ValueError("Input shapes must be list or tuple.")
    for shape in [shape_1, shape_2]:
        if not all([isinstance(x, int) or x is None for x in shape]):
            raise ValueError(
                "Input shapes must be list or tuple of int and None values.")

    if len(shape_1) != len(shape_2):
        return False
    else:
        for axis, value in enumerate(shape_1):
            if value is not None and shape_2[axis] not in {value, None}:
                return False
        return True


def _check_shapes(signal_int_shape, proto_int_shape):
    if len(signal_int_shape) < 4:
        raise ValueError(
            "The number of signal dimensions must be >=4. You provide: " +
            str(len(signal_int_shape)))

    if len(proto_int_shape) < 2:
        raise ValueError(
            "The number of proto dimensions must be >=2. You provide: " +
            str(len(proto_int_shape)))

    if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
        raise ValueError(
            "The atom shape of signals must be equal to protos. You provide: signals.shape[3:]="
            + str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
            str(proto_int_shape[1:]))

    # not a sparse signal
    if signal_int_shape[1] != 1:
        if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
            raise ValueError(
                "If the signal is not sparse, the number of prototypes must be equal in signals and "
                "protos. You provide: " + str(signal_int_shape[1]) + " != " +
                str(proto_int_shape[0]))

    return True


def _int_and_mixed_shape(tensor):
    shape = mixed_shape(tensor)
    int_shape = tuple([i if isinstance(i, int) else None for i in shape])

    return shape, int_shape
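A small sketch of the helper utilities; the import path `prototorch.functions.helper` is the one used by the distance module above:

    import torch
    from prototorch.functions.helper import equal_int_shape, get_flat

    images = torch.randn(4, 1, 28, 28)
    flat, = get_flat(images)  # flattened to shape (4, 784)
    print(equal_int_shape((4, 784), tuple(flat.shape)))  # True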

@@ -1,107 +0,0 @@
"""ProtoTorch initialization functions."""

from itertools import chain

import torch

INITIALIZERS = dict()


def register_initializer(function):
    """Add the initializer to the registry."""
    INITIALIZERS[function.__name__] = function
    return function


def labels_from(distribution, one_hot=True):
    """Takes a distribution tensor and returns a labels tensor."""
    num_classes = distribution.shape[0]
    llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
    # labels = [l for cl in llist for l in cl]  # flatten the list of lists
    flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
    plabels = torch.tensor(flat_llist, requires_grad=False)
    if one_hot:
        return torch.eye(num_classes)[plabels]
    return plabels


@register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.ones(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.zeros(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.rand(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.randn(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        mean_xl = torch.mean(xl, dim=0)
        protos[i] = mean_xl
    plabels = labels_from(prototype_distribution, one_hot=one_hot)
    return protos, plabels


@register_initializer
def stratified_random(x_train,
                      y_train,
                      prototype_distribution,
                      one_hot=True,
                      epsilon=1e-7):
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
        random_xl = xl[rand_index]
        protos[i] = random_xl + epsilon
    plabels = labels_from(prototype_distribution, one_hot=one_hot)
    return protos, plabels


def get_initializer(funcname):
    """Deserialize the initializer."""
    if callable(funcname):
        return funcname
    if funcname in INITIALIZERS:
        return INITIALIZERS.get(funcname)
    raise NameError(f"Initializer {funcname} was not found.")
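A usage sketch for the registered initializers; the module path is not shown in this diff, so `prototorch.functions.initializers` is assumed here:

    import torch
    from prototorch.functions.initializers import get_initializer

    x_train = torch.randn(100, 2)
    y_train = torch.randint(0, 2, (100, ))
    pdist = torch.tensor([1, 1])  # one prototype per class
    init = get_initializer("stratified_mean")
    protos, plabels = init(x_train, y_train, pdist, one_hot=False)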

@@ -1,94 +0,0 @@
"""ProtoTorch loss functions."""

import torch


def _get_matcher(targets, labels):
    """Returns a boolean tensor."""
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim == 2:
        # if the labels are one-hot vectors
        num_classes = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    return matcher


def _get_dp_dm(distances, targets, plabels, with_indices=False):
    """Returns the d+ and d- values for a batch of distances."""
    matcher = _get_matcher(targets, plabels)
    not_matcher = torch.bitwise_not(matcher)

    inf = torch.full_like(distances, fill_value=float("inf"))
    d_matching = torch.where(matcher, distances, inf)
    d_unmatching = torch.where(not_matcher, distances, inf)
    dp = torch.min(d_matching, dim=-1, keepdim=True)
    dm = torch.min(d_unmatching, dim=-1, keepdim=True)
    if with_indices:
        return dp, dm
    return dp.values, dm.values


def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels."""
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = (dp - dm) / (dp + dm)
    return mu


def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp
    mu[dp > dm] = -dm[dp > dm]
    return mu


def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp - dm

    return mu


# Probabilistic
def _get_class_probabilities(probabilities, targets, prototype_labels):
    # Create Label Mapping
    uniques = prototype_labels.unique(sorted=True).tolist()
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}

    target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))

    whole = probabilities.sum(dim=1)
    correct = probabilities[torch.arange(len(probabilities)), target_indices]
    wrong = whole - correct

    return whole, correct, wrong


def nllr_loss(probabilities, targets, prototype_labels):
    """Compute the Negative Log-Likelihood Ratio loss."""
    _, correct, wrong = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / wrong
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


def rslvq_loss(probabilities, targets, prototype_labels):
    """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
    whole, correct, _ = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / whole
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood
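A minimal sketch of the GLVQ loss applied to a batch of distances; the module path is not shown in this diff, so `prototorch.functions.losses` is assumed here:

    import torch
    from prototorch.functions.losses import glvq_loss

    distances = torch.rand(4, 3)  # 4 samples, 3 prototypes
    plabels = torch.LongTensor([0, 1, 2])
    targets = torch.LongTensor([0, 2, 1, 0])
    mu = glvq_loss(distances, targets, plabels)  # values in (-1, 1); lower is better
    loss = mu.sum()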

@@ -1,35 +0,0 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import torch


def orthogonalization(tensors):
    r"""Orthogonalization of a given tensor via polar decomposition."""
    u, _, v = torch.svd(tensors, compute_uv=True)
    u_shape = tuple(list(u.shape))
    v_shape = tuple(list(v.shape))

    # reshape to (num x N x M)
    u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
    v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))

    out = u @ v.permute([0, 2, 1])

    out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))

    return out


def trace_normalization(tensors):
    r"""Trace normalization."""
    epsilon = torch.tensor([1e-10], dtype=torch.float64)
    # Scope trace_normalization
    constant = torch.trace(tensors)

    if epsilon != 0:
        constant = torch.max(constant, epsilon)

    return tensors / constant

@@ -1,80 +0,0 @@
"""ProtoTorch pooling functions."""

from typing import Callable

import torch


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values
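A sketch of stratified pooling over prototype columns, assuming the now-removed import path `prototorch.functions.pooling`:

    import torch
    from prototorch.functions.pooling import stratified_min_pooling

    values = torch.rand(4, 6)  # 4 samples, 6 prototypes
    plabels = torch.LongTensor([0, 0, 1, 1, 2, 2])  # prototype class labels
    class_values = stratified_min_pooling(values, plabels)  # shape (4, 3), one column per class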

@@ -1,18 +0,0 @@
"""ProtoTorch similarity functions."""

import torch


def cosine_similarity(x, y):
    """Compute the cosine similarity between :math:`x` and :math:`y`.

    Expected dimension of x is 2.
    Expected dimension of y is 2.
    """
    norm_x = x.pow(2).sum(1).sqrt()
    norm_y = y.pow(2).sum(1).sqrt()
    norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
    epsilon = torch.finfo(norm_mat.dtype).eps
    norm_mat.clamp_(min=epsilon)
    similarities = (x @ y.T) / norm_mat
    return similarities
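With the definition above in scope, the similarity matrix between two batches of vectors could be computed as follows (a sketch, not taken from the diff):

    import torch

    x = torch.randn(4, 5)
    y = torch.randn(3, 5)
    sim = cosine_similarity(x, y)  # shape (4, 3), values in [-1, 1]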

@@ -1,32 +0,0 @@
import torch


# Functions
def gaussian(distances, variance):
    return torch.exp(-(distances * distances) / (2 * variance))


def rank_scaled_gaussian(distances, lambd):
    order = torch.argsort(distances, dim=1)
    ranks = torch.argsort(order, dim=1)

    return torch.exp(-torch.exp(-ranks / lambd) * distances)


# Modules
class GaussianPrior(torch.nn.Module):
    def __init__(self, variance):
        super().__init__()
        self.variance = variance

    def forward(self, distances):
        return gaussian(distances, self.variance)


class RankScaledGaussianPrior(torch.nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, distances):
        return rank_scaled_gaussian(distances, self.lambd)
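With the classes above in scope, distances could be converted into Gaussian activations like this (a sketch; the module path is not shown in this diff):

    import torch

    prior = GaussianPrior(variance=1.0)
    distances = torch.rand(4, 3)
    activations = prior(distances)  # exp(-d^2 / (2 * variance)), element-wise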

@@ -1,5 +0,0 @@
"""ProtoTorch modules."""

from .competitions import *
from .pooling import *
from .wrappers import LambdaLayer, LossLayer