[REFACTOR] Reorganize files and folders
This commit is contained in:
parent
25dbde4e43
commit
093a79d533
@@ -1,20 +1,36 @@
-"""ProtoTorch package."""
+"""ProtoTorch package"""
 
 import pkgutil
 
 import pkg_resources
 
-from . import components, datasets, functions, modules, utils
-from .datasets import *
+from . import (
+    datasets,
+    nn,
+    utils,
+)
+from .core import (
+    competitions,
+    components,
+    distances,
+    initializers,
+    losses,
+    pooling,
+)
 
 # Core Setup
 __version__ = "0.5.0"
 
 __all_core__ = [
-    "datasets",
-    "functions",
-    "modules",
+    "competitions",
     "components",
+    "core",
+    "datasets",
+    "distances",
+    "initializers",
+    "losses",
+    "nn",
+    "pooling",
     "utils",
 ]
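The package now re-exports the new `core` and `nn` subpackages instead of `functions` and `modules`. A minimal import sketch (an assumption about downstream usage, not part of this commit; the old paths are the ones removed elsewhere in this diff):

import prototorch as pt

# old: from prototorch.functions.competitions import wtac
from prototorch.core.competitions import wtac

# old: from prototorch.functions.activations import get_activation
from prototorch.nn.activations import get_activation

pt.datasets  # unchanged top-level subpackage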
@@ -1,5 +1,6 @@
 """ProtoTorch core"""
 
+from .competitions import *
 from .components import *
 from .initializers import *
-from .labels import *
+from .losses import *
@@ -1,7 +1,31 @@
-"""ProtoTorch Competition Modules."""
+"""ProtoTorch competitions"""
 
 import torch
-from prototorch.functions.competitions import knnc, wtac
+
+
+def wtac(distances: torch.Tensor,
+         labels: torch.LongTensor) -> (torch.LongTensor):
+    """Winner-Takes-All-Competition.
+
+    Returns the labels corresponding to the winners.
+
+    """
+    winning_indices = torch.min(distances, dim=1).indices
+    winning_labels = labels[winning_indices].squeeze()
+    return winning_labels
+
+
+def knnc(distances: torch.Tensor,
+         labels: torch.LongTensor,
+         k: int = 1) -> (torch.LongTensor):
+    """K-Nearest-Neighbors-Competition.
+
+    Returns the labels corresponding to the winners.
+
+    """
+    winning_indices = torch.topk(-distances, k=k, dim=1).indices
+    winning_labels = torch.mode(labels[winning_indices], dim=1).values
+    return winning_labels
 
 
 class WTAC(torch.nn.Module):
@@ -10,7 +34,6 @@ class WTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
 
     """
-
     def forward(self, distances, labels):
         return wtac(distances, labels)
 
@@ -21,7 +44,6 @@ class LTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
 
     """
-
     def forward(self, probs, labels):
         return wtac(-1.0 * probs, labels)
 
@@ -32,7 +54,6 @@ class KNNC(torch.nn.Module):
     Thin wrapper over the `knnc` function.
 
     """
-
     def __init__(self, k=1, **kwargs):
         super().__init__(**kwargs)
         self.k = k
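A small usage sketch for the functions above (the toy tensors are assumptions, not from the commit): three prototypes labeled 0, 0, 1 and a batch of two distance rows.

import torch
from prototorch.core.competitions import knnc, wtac

distances = torch.tensor([[2.0, 0.1, 1.5],
                          [0.3, 0.9, 0.2]])
plabels = torch.LongTensor([0, 0, 1])

wtac(distances, plabels)       # tensor([0, 1]): label of the closest prototype per row
knnc(distances, plabels, k=3)  # tensor([0, 0]): majority label among the 3 nearest prototypes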
prototorch/core/distances.py (new file, 261 lines)
@@ -0,0 +1,261 @@
"""ProtoTorch distances"""

import numpy as np
import torch

# from prototorch.functions.helper import (
#     _check_shapes,
#     _int_and_mixed_shape,
#     equal_int_shape,
#     get_flat,
# )


def squared_euclidean_distance(x, y):
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.

    Compute :math:`{\langle \bm x - \bm y \rangle}_2`

    **Alias:**
    ``prototorch.functions.distances.sed``
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    expanded_x = x.unsqueeze(dim=1)
    batchwise_difference = y - expanded_x
    differences_raised = torch.pow(batchwise_difference, 2)
    distances = torch.sum(differences_raised, axis=2)
    return distances


def euclidean_distance(x, y):
    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.

    Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
    # batch diagonal. See:
    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
    # print(f"{diff.shape=}")  # (nx, ny, ndim)
    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
    # print(f"{distances.shape=}")  # (nx, ny)
    return distances


def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
    Also known as Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``

    :param p: p parameter of the lp norm
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances = torch.cdist(x, y, p=p)
    return distances


def omega_distance(x, y, omega):
    r"""Omega distance.

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
    r"""Localized Omega distance.

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
    batchwise_difference = expanded_y - projected_x
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)
    distances = distances.permute(1, 0)
    return distances


# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
#     r"""Computes a Euclidean distance matrix given two distinct vectors.
#     The last dimension must be the vector dimension!
#     Computes the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
#
#     - ``x.shape = (number_of_x_vectors, vector_dim)``
#     - ``y.shape = (number_of_y_vectors, vector_dim)``
#
#     output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
#     """
#     for tensor in [x, y]:
#         if tensor.ndim != 2:
#             raise ValueError(
#                 "The tensor dimension must be two. You provide: tensor.ndim=" +
#                 str(tensor.ndim) + ".")
#     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
#         raise ValueError(
#             "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
#             + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
#             str(tuple(y.shape)[1]) + ".")
#
#     y = torch.transpose(y)
#
#     diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
#             torch.sum(y**2, axis=0, keepdims=True))
#
#     if not squared:
#         if epsilon == 0:
#             diss = torch.sqrt(diss)
#         else:
#             diss = torch.sqrt(torch.max(diss, epsilon))
#
#     return diss


# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
#     r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
#
#     For more info about Tangent distances see
#
#     DOI:10.1109/IJCNN.2016.7727534.
#
#     The subspaces is always assumed as transposed and must be orthogonal!
#     For local non sparse signals subspaces must be provided!
#
#     - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
#     - shape(protos): proto_number x dim1 x dim2 x ... x dimN
#     - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
#
#     subspace should be orthogonalized
#     Pytorch implementation of Sascha Saralajew's tensorflow code.
#     Translation by Christoph Raab
#     """
#     signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
#     proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
#     subspace_int_shape = tuple(subspaces.shape)
#
#     # check if the shapes are correct
#     _check_shapes(signal_int_shape, proto_int_shape)
#
#     atom_axes = list(range(3, len(signal_int_shape)))
#     # for sparse signals, we use the memory efficient implementation
#     if signal_int_shape[1] == 1:
#         signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
#
#         if len(atom_axes) > 1:
#             protos = torch.reshape(protos, [proto_shape[0], -1])
#
#         if subspaces.ndim == 2:
#             # clean solution without map if the matrix_scope is global
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
#                 subspaces, torch.transpose(subspaces))
#
#             projected_signals = torch.dot(signals, projectors)
#             projected_protos = torch.dot(protos, projectors)
#
#             diss = euclidean_distance_matrix(projected_signals,
#                                              projected_protos,
#                                              squared=squared,
#                                              epsilon=epsilon)
#
#             diss = torch.reshape(
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
#
#             return torch.permute(diss, [0, 2, 1])
#
#         else:
#
#             # no solution without map possible --> memory efficient but slow!
#             projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
#                 subspaces,
#                 subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])
#
#             projected_protos = (protos @ subspaces
#                                 ).T  # K.batch_dot(projectors, protos, [1, 1]))
#
#             def projected_norm(projector):
#                 return torch.sum(torch.dot(signals, projector)**2, axis=1)
#
#             diss = (torch.transpose(map(projected_norm, projectors)) -
#                     2 * torch.dot(signals, projected_protos) +
#                     torch.sum(projected_protos**2, axis=0, keepdims=True))
#
#             if not squared:
#                 if epsilon == 0:
#                     diss = torch.sqrt(diss)
#                 else:
#                     diss = torch.sqrt(torch.max(diss, epsilon))
#
#             diss = torch.reshape(
#                 diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
#
#             return torch.permute(diss, [0, 2, 1])
#
#     else:
#         signals = signals.permute([0, 2, 1] + atom_axes)
#
#         diff = signals - protos
#
#         # global tangent space
#         if subspaces.ndim == 2:
#             # Scope Projectors
#             projectors = subspaces  #
#
#             # Scope: Tangentspace Projections
#             diff = torch.reshape(
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
#             projected_diff = diff @ projectors
#             projected_diff = torch.reshape(
#                 projected_diff,
#                 (signal_shape[0], signal_shape[2], signal_shape[1]) +
#                 signal_shape[3:],
#             )
#
#             diss = torch.norm(projected_diff, 2, dim=-1)
#             return diss.permute([0, 2, 1])
#
#         # local tangent spaces
#         else:
#             # Scope: Calculate Projectors
#             projectors = subspaces
#
#             # Scope: Tangentspace Projections
#             diff = torch.reshape(
#                 diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
#             diff = diff.permute([1, 0, 2])
#             projected_diff = torch.bmm(diff, projectors)
#             projected_diff = torch.reshape(
#                 projected_diff,
#                 (signal_shape[1], signal_shape[0], signal_shape[2]) +
#                 signal_shape[3:],
#             )
#
#             diss = torch.norm(projected_diff, 2, dim=-1)
#             return diss.permute([1, 0, 2]).squeeze(-1)


# Aliases
sed = squared_euclidean_distance
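A short shape-oriented sketch of the public distance functions above (the random inputs are assumptions): a batch of two samples against three prototypes, all 2-dimensional.

import torch
from prototorch.core.distances import (euclidean_distance, lpnorm_distance,
                                        omega_distance, sed)

x = torch.randn(2, 2)  # samples
y = torch.randn(3, 2)  # prototypes

sed(x, y)                           # squared Euclidean distances, shape (2, 3)
euclidean_distance(x, y)            # Euclidean distances, shape (2, 3)
lpnorm_distance(x, y, p=1)          # Manhattan distances via torch.cdist
omega_distance(x, y, torch.eye(2))  # squared distances after projecting with Omega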
prototorch/core/losses.py (new file, 151 lines)
@@ -0,0 +1,151 @@
"""ProtoTorch losses"""

import torch

from ..nn.activations import get_activation


# Helpers
def _get_matcher(targets, labels):
    """Returns a boolean tensor."""
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim == 2:
        # if the labels are one-hot vectors
        num_classes = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    return matcher


def _get_dp_dm(distances, targets, plabels, with_indices=False):
    """Returns the d+ and d- values for a batch of distances."""
    matcher = _get_matcher(targets, plabels)
    not_matcher = torch.bitwise_not(matcher)

    inf = torch.full_like(distances, fill_value=float("inf"))
    d_matching = torch.where(matcher, distances, inf)
    d_unmatching = torch.where(not_matcher, distances, inf)
    dp = torch.min(d_matching, dim=-1, keepdim=True)
    dm = torch.min(d_unmatching, dim=-1, keepdim=True)
    if with_indices:
        return dp, dm
    return dp.values, dm.values


# GLVQ
def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels."""
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = (dp - dm) / (dp + dm)
    return mu


def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp
    mu[dp > dm] = -dm[dp > dm]
    return mu


def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp - dm

    return mu


# Probabilistic
def _get_class_probabilities(probabilities, targets, prototype_labels):
    # Create Label Mapping
    uniques = prototype_labels.unique(sorted=True).tolist()
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}

    target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))

    whole = probabilities.sum(dim=1)
    correct = probabilities[torch.arange(len(probabilities)), target_indices]
    wrong = whole - correct

    return whole, correct, wrong


def nllr_loss(probabilities, targets, prototype_labels):
    """Compute the Negative Log-Likelihood Ratio loss."""
    _, correct, wrong = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / wrong
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


def rslvq_loss(probabilities, targets, prototype_labels):
    """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
    whole, correct, _ = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / whole
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


class GLVQLoss(torch.nn.Module):
    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = torch.tensor(beta)

    def forward(self, outputs, targets):
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, prototype_labels=plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)


class NeuralGasEnergy(torch.nn.Module):
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    def forward(self, d):
        order = torch.argsort(d, dim=1)
        ranks = torch.argsort(order, dim=1)
        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)

        return cost, order

    def extra_repr(self):
        return f"lambda: {self.lm}"

    @staticmethod
    def _nghood_fn(rankings, lm):
        return torch.exp(-rankings / lm)


class GrowingNeuralGasEnergy(NeuralGasEnergy):
    def __init__(self, topology_layer, **kwargs):
        super().__init__(**kwargs)
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
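A usage sketch for the GLVQ pieces above, reusing the toy distances from the competitions example (the values and the choice of squashing are assumptions, not from the commit):

import torch
from prototorch.core.losses import GLVQLoss, glvq_loss

distances = torch.tensor([[2.0, 0.1, 1.5],
                          [0.3, 0.9, 0.2]])
plabels = torch.LongTensor([0, 0, 1])
targets = torch.LongTensor([0, 1])

mu = glvq_loss(distances, targets, plabels)  # per-sample scores in (-1, 1), shape (2, 1)

criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
loss = criterion((distances, plabels), targets)  # squashed scores summed over the batch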
prototorch/core/pooling.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""ProtoTorch pooling"""

from typing import Callable

import torch


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
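A usage sketch for the stratified pooling helpers above, again with the toy tensors from the earlier examples (assumptions, not from the commit):

import torch
from prototorch.core.pooling import stratified_min_pooling, stratified_sum_pooling

distances = torch.tensor([[2.0, 0.1, 1.5],
                          [0.3, 0.9, 0.2]])
plabels = torch.LongTensor([0, 0, 1])

stratified_min_pooling(distances, plabels)  # [[0.1, 1.5], [0.3, 0.2]]: min distance per class
stratified_sum_pooling(distances, plabels)  # per-class sums, shape (batch, num_classes) = (2, 2)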
@@ -1,58 +0,0 @@
"""ProtoTorch losses."""

import torch

from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss


class GLVQLoss(torch.nn.Module):
    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = torch.tensor(beta)

    def forward(self, outputs, targets):
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, prototype_labels=plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)


class NeuralGasEnergy(torch.nn.Module):
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    def forward(self, d):
        order = torch.argsort(d, dim=1)
        ranks = torch.argsort(order, dim=1)
        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)

        return cost, order

    def extra_repr(self):
        return f"lambda: {self.lm}"

    @staticmethod
    def _nghood_fn(rankings, lm):
        return torch.exp(-rankings / lm)


class GrowingNeuralGasEnergy(NeuralGasEnergy):
    def __init__(self, topology_layer, **kwargs):
        super().__init__(**kwargs)
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
@@ -1,31 +0,0 @@
"""ProtoTorch Pooling Modules."""

import torch

from prototorch.functions.pooling import (stratified_max_pooling,
                                          stratified_min_pooling,
                                          stratified_prod_pooling,
                                          stratified_sum_pooling)


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
prototorch/nn/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""ProtoTorch Neural Network Module"""

from .activations import *
from .wrappers import *
prototorch/nn/activations.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""ProtoTorch activations"""

import torch

ACTIVATIONS = dict()


def register_activation(fn):
    """Add the activation function to the registry."""
    name = fn.__name__
    ACTIVATIONS[name] = fn
    return fn


@register_activation
def identity(x, beta=0.0):
    """Identity activation function.

    Definition:
    :math:`f(x) = x`

    Keyword Arguments:
    beta (`float`): Ignored.
    """
    return x


@register_activation
def sigmoid_beta(x, beta=10.0):
    r"""Sigmoid activation function with scaling.

    Definition:
    :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`

    Keyword Arguments:
    beta (`float`): Scaling parameter :math:`\beta`
    """
    out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
    return out


@register_activation
def swish_beta(x, beta=10.0):
    r"""Swish activation function with scaling.

    Definition:
    :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`

    Keyword Arguments:
    beta (`float`): Scaling parameter :math:`\beta`
    """
    out = x * sigmoid_beta(x, beta=beta)
    return out


def get_activation(funcname):
    """Deserialize the activation function."""
    if callable(funcname):
        return funcname
    if funcname in ACTIVATIONS:
        return ACTIVATIONS.get(funcname)
    raise NameError(f"Activation {funcname} was not found.")
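A usage sketch for the activation registry above (the custom function is a hypothetical example, not part of the commit):

import torch
from prototorch.nn.activations import get_activation, register_activation

squash = get_activation("swish_beta")  # resolve a registered activation by name
squash(torch.tensor([-1.0, 0.0, 1.0]), beta=5.0)

@register_activation
def relu_like(x, beta=0.0):  # hypothetical custom activation; beta is ignored
    return torch.clamp(x, min=0.0)

get_activation("relu_like")  # now resolvable through the registry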
@@ -1,4 +1,4 @@
-"""ProtoTorch Wrappers."""
+"""ProtoTorch wrappers."""
 
 import torch
 