[REFACTOR] Reorganize files and folders

This commit is contained in:
Jensun Ravichandran 2021-06-12 20:38:16 +02:00
parent 25dbde4e43
commit 093a79d533
11 changed files with 633 additions and 102 deletions

View File

@ -1,20 +1,36 @@
"""ProtoTorch package."""
"""ProtoTorch package"""
import pkgutil
import pkg_resources
from . import components, datasets, functions, modules, utils
from .datasets import *
from . import (
datasets,
nn,
utils,
)
from .core import (
competitions,
components,
distances,
initializers,
losses,
pooling,
)
# Core Setup
__version__ = "0.5.0"
__all_core__ = [
"datasets",
"functions",
"modules",
"competitions",
"components",
"core",
"datasets",
"distances",
"initializers",
"losses",
"nn",
"pooling",
"utils",
]

View File

@ -1,5 +1,6 @@
"""ProtoTorch core"""
from .competitions import *
from .components import *
from .initializers import *
from .labels import *
from .losses import *

View File

@ -1,7 +1,31 @@
"""ProtoTorch Competition Modules."""
"""ProtoTorch competitions"""
import torch
from prototorch.functions.competitions import knnc, wtac
def wtac(distances: torch.Tensor,
labels: torch.LongTensor) -> (torch.LongTensor):
"""Winner-Takes-All-Competition.
Returns the labels corresponding to the winners.
"""
winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels
def knnc(distances: torch.Tensor,
labels: torch.LongTensor,
k: int = 1) -> (torch.LongTensor):
"""K-Nearest-Neighbors-Competition.
Returns the labels corresponding to the winners.
"""
winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_labels = torch.mode(labels[winning_indices], dim=1).values
return winning_labels
class WTAC(torch.nn.Module):
@ -10,7 +34,6 @@ class WTAC(torch.nn.Module):
Thin wrapper over the `wtac` function.
"""
def forward(self, distances, labels):
return wtac(distances, labels)
@ -21,7 +44,6 @@ class LTAC(torch.nn.Module):
Thin wrapper over the `wtac` function.
"""
def forward(self, probs, labels):
return wtac(-1.0 * probs, labels)
@ -32,7 +54,6 @@ class KNNC(torch.nn.Module):
Thin wrapper over the `knnc` function.
"""
def __init__(self, k=1, **kwargs):
super().__init__(**kwargs)
self.k = k

View File

@ -0,0 +1,261 @@
"""ProtoTorch distances"""
import numpy as np
import torch
# from prototorch.functions.helper import (
# _check_shapes,
# _int_and_mixed_shape,
# equal_int_shape,
# get_flat,
# )
def squared_euclidean_distance(x, y):
r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
**Alias:**
``prototorch.functions.distances.sed``
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x
differences_raised = torch.pow(batchwise_difference, 2)
distances = torch.sum(differences_raised, axis=2)
return distances
def euclidean_distance(x, y):
r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
:returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised)
return distances
def euclidean_distance_v2(x, y):
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
# batch diagonal. See:
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
# print(f"{diff.shape=}") # (nx, ny, ndim)
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
# print(f"{distances.shape=}") # (nx, ny)
return distances
def lpnorm_distance(x, y, p):
r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
Also known as Minkowski distance.
Compute :math:`{\| \bm x - \bm y \|}_p`.
Calls ``torch.cdist``
:param p: p parameter of the lp norm
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
distances = torch.cdist(x, y, p=p)
return distances
def omega_distance(x, y, omega):
r"""Omega distance.
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
:param `torch.tensor` omega: Two dimensional matrix
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
projected_x = x @ omega
projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y)
return distances
def lomega_distance(x, y, omegas):
r"""Localized Omega distance.
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
:param `torch.tensor` omegas: Three dimensional matrix
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sum(differences_squared, dim=2)
distances = distances.permute(1, 0)
return distances
# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
# r"""Computes an euclidean distances matrix given two distinct vectors.
# last dimension must be the vector dimension!
# compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
# - ``x.shape = (number_of_x_vectors, vector_dim)``
# - ``y.shape = (number_of_y_vectors, vector_dim)``
# output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
# """
# for tensor in [x, y]:
# if tensor.ndim != 2:
# raise ValueError(
# "The tensor dimension must be two. You provide: tensor.ndim=" +
# str(tensor.ndim) + ".")
# if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
# raise ValueError(
# "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
# + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
# str(tuple(y.shape)[1]) + ".")
# y = torch.transpose(y)
# diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
# torch.sum(y**2, axis=0, keepdims=True))
# if not squared:
# if epsilon == 0:
# diss = torch.sqrt(diss)
# else:
# diss = torch.sqrt(torch.max(diss, epsilon))
# return diss
# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
# r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
# For more info about Tangen distances see
# DOI:10.1109/IJCNN.2016.7727534.
# The subspaces is always assumed as transposed and must be orthogonal!
# For local non sparse signals subspaces must be provided!
# - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# - shape(protos): proto_number x dim1 x dim2 x ... x dimN
# - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
# Pytorch implementation of Sascha Saralajew's tensorflow code.
# Translation by Christoph Raab
# """
# signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
# proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
# subspace_int_shape = tuple(subspaces.shape)
# # check if the shapes are correct
# _check_shapes(signal_int_shape, proto_int_shape)
# atom_axes = list(range(3, len(signal_int_shape)))
# # for sparse signals, we use the memory efficient implementation
# if signal_int_shape[1] == 1:
# signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
# if len(atom_axes) > 1:
# protos = torch.reshape(protos, [proto_shape[0], -1])
# if subspaces.ndim == 2:
# # clean solution without map if the matrix_scope is global
# projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
# subspaces, torch.transpose(subspaces))
# projected_signals = torch.dot(signals, projectors)
# projected_protos = torch.dot(protos, projectors)
# diss = euclidean_distance_matrix(projected_signals,
# projected_protos,
# squared=squared,
# epsilon=epsilon)
# diss = torch.reshape(
# diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
# return torch.permute(diss, [0, 2, 1])
# else:
# # no solution without map possible --> memory efficient but slow!
# projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
# subspaces,
# subspaces) # K.batch_dot(subspaces, subspaces, [2, 2])
# projected_protos = (protos @ subspaces
# ).T # K.batch_dot(projectors, protos, [1, 1]))
# def projected_norm(projector):
# return torch.sum(torch.dot(signals, projector)**2, axis=1)
# diss = (torch.transpose(map(projected_norm, projectors)) -
# 2 * torch.dot(signals, projected_protos) +
# torch.sum(projected_protos**2, axis=0, keepdims=True))
# if not squared:
# if epsilon == 0:
# diss = torch.sqrt(diss)
# else:
# diss = torch.sqrt(torch.max(diss, epsilon))
# diss = torch.reshape(
# diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
# return torch.permute(diss, [0, 2, 1])
# else:
# signals = signals.permute([0, 2, 1] + atom_axes)
# diff = signals - protos
# # global tangent space
# if subspaces.ndim == 2:
# # Scope Projectors
# projectors = subspaces #
# # Scope: Tangentspace Projections
# diff = torch.reshape(
# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
# projected_diff = diff @ projectors
# projected_diff = torch.reshape(
# projected_diff,
# (signal_shape[0], signal_shape[2], signal_shape[1]) +
# signal_shape[3:],
# )
# diss = torch.norm(projected_diff, 2, dim=-1)
# return diss.permute([0, 2, 1])
# # local tangent spaces
# else:
# # Scope: Calculate Projectors
# projectors = subspaces
# # Scope: Tangentspace Projections
# diff = torch.reshape(
# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
# diff = diff.permute([1, 0, 2])
# projected_diff = torch.bmm(diff, projectors)
# projected_diff = torch.reshape(
# projected_diff,
# (signal_shape[1], signal_shape[0], signal_shape[2]) +
# signal_shape[3:],
# )
# diss = torch.norm(projected_diff, 2, dim=-1)
# return diss.permute([1, 0, 2]).squeeze(-1)
# Aliases
sed = squared_euclidean_distance

151
prototorch/core/losses.py Normal file
View File

@ -0,0 +1,151 @@
"""ProtoTorch losses"""
import torch
from ..nn.activations import get_activation
# Helpers
def _get_matcher(targets, labels):
"""Returns a boolean tensor."""
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
if labels.ndim == 2:
# if the labels are one-hot vectors
num_classes = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
return matcher
def _get_dp_dm(distances, targets, plabels, with_indices=False):
"""Returns the d+ and d- values for a batch of distances."""
matcher = _get_matcher(targets, plabels)
not_matcher = torch.bitwise_not(matcher)
inf = torch.full_like(distances, fill_value=float("inf"))
d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=-1, keepdim=True)
dm = torch.min(d_unmatching, dim=-1, keepdim=True)
if with_indices:
return dp, dm
return dp.values, dm.values
# GLVQ
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm)
return mu
def lvq1_loss(distances, target_labels, prototype_labels):
"""LVQ1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp
mu[dp > dm] = -dm[dp > dm]
return mu
def lvq21_loss(distances, target_labels, prototype_labels):
"""LVQ2.1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm
return mu
# Probabilistic
def _get_class_probabilities(probabilities, targets, prototype_labels):
# Create Label Mapping
uniques = prototype_labels.unique(sorted=True).tolist()
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))
whole = probabilities.sum(dim=1)
correct = probabilities[torch.arange(len(probabilities)), target_indices]
wrong = whole - correct
return whole, correct, wrong
def nllr_loss(probabilities, targets, prototype_labels):
"""Compute the Negative Log-Likelihood Ratio loss."""
_, correct, wrong = _get_class_probabilities(probabilities, targets,
prototype_labels)
likelihood = correct / wrong
log_likelihood = torch.log(likelihood)
return -1.0 * log_likelihood
def rslvq_loss(probabilities, targets, prototype_labels):
"""Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
whole, correct, _ = _get_class_probabilities(probabilities, targets,
prototype_labels)
likelihood = correct / whole
log_likelihood = torch.log(likelihood)
return -1.0 * log_likelihood
class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = torch.tensor(beta)
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)
class NeuralGasEnergy(torch.nn.Module):
def __init__(self, lm, **kwargs):
super().__init__(**kwargs)
self.lm = lm
def forward(self, d):
order = torch.argsort(d, dim=1)
ranks = torch.argsort(order, dim=1)
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
return cost, order
def extra_repr(self):
return f"lambda: {self.lm}"
@staticmethod
def _nghood_fn(rankings, lm):
return torch.exp(-rankings / lm)
class GrowingNeuralGasEnergy(NeuralGasEnergy):
def __init__(self, topology_layer, **kwargs):
super().__init__(**kwargs)
self.topology_layer = topology_layer
@staticmethod
def _nghood_fn(rankings, topology):
winner = rankings[:, 0]
weights = torch.zeros_like(rankings, dtype=torch.float)
weights[torch.arange(rankings.shape[0]), winner] = 1.0
neighbours = topology.get_neighbours(winner)
weights[neighbours] = 0.1
return weights

104
prototorch/core/pooling.py Normal file
View File

@ -0,0 +1,104 @@
"""ProtoTorch pooling"""
from typing import Callable
import torch
def stratify_with(values: torch.Tensor,
labels: torch.LongTensor,
fn: Callable,
fill_value: float = 0.0) -> (torch.Tensor):
"""Apply an arbitrary stratification strategy on the columns on `values`.
The outputs correspond to sorted labels.
"""
clabels = torch.unique(labels, dim=0, sorted=True)
num_classes = clabels.size()[0]
if values.size()[1] == num_classes:
# skip if stratification is trivial
return values
batch_size = values.size()[0]
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
filler = torch.full_like(values.T, fill_value=fill_value)
for i, cl in enumerate(clabels):
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2:
# if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
cdists = torch.where(matcher, values.T, filler).T
winning_values[i] = fn(cdists)
if labels.ndim == 2:
# Transpose to return with `batch_size` first and
# reverse the columns to fix the ordering of the classes
return torch.flip(winning_values.T, dims=(1, ))
return winning_values.T # return with `batch_size` first
def stratified_sum_pooling(values: torch.Tensor,
labels: torch.LongTensor) -> (torch.Tensor):
"""Group-wise sum."""
winning_values = stratify_with(
values,
labels,
fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
fill_value=0.0)
return winning_values
def stratified_min_pooling(values: torch.Tensor,
labels: torch.LongTensor) -> (torch.Tensor):
"""Group-wise minimum."""
winning_values = stratify_with(
values,
labels,
fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
fill_value=float("inf"))
return winning_values
def stratified_max_pooling(values: torch.Tensor,
labels: torch.LongTensor) -> (torch.Tensor):
"""Group-wise maximum."""
winning_values = stratify_with(
values,
labels,
fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
fill_value=-1.0 * float("inf"))
return winning_values
def stratified_prod_pooling(values: torch.Tensor,
labels: torch.LongTensor) -> (torch.Tensor):
"""Group-wise maximum."""
winning_values = stratify_with(
values,
labels,
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
fill_value=1.0)
return winning_values
class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels):
return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_prod_pooling` function."""
def forward(self, values, labels):
return stratified_prod_pooling(values, labels)
class StratifiedMinPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_min_pooling` function."""
def forward(self, values, labels):
return stratified_min_pooling(values, labels)
class StratifiedMaxPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_max_pooling` function."""
def forward(self, values, labels):
return stratified_max_pooling(values, labels)

View File

@ -1,58 +0,0 @@
"""ProtoTorch losses."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = torch.tensor(beta)
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)
class NeuralGasEnergy(torch.nn.Module):
def __init__(self, lm, **kwargs):
super().__init__(**kwargs)
self.lm = lm
def forward(self, d):
order = torch.argsort(d, dim=1)
ranks = torch.argsort(order, dim=1)
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
return cost, order
def extra_repr(self):
return f"lambda: {self.lm}"
@staticmethod
def _nghood_fn(rankings, lm):
return torch.exp(-rankings / lm)
class GrowingNeuralGasEnergy(NeuralGasEnergy):
def __init__(self, topology_layer, **kwargs):
super().__init__(**kwargs)
self.topology_layer = topology_layer
@staticmethod
def _nghood_fn(rankings, topology):
winner = rankings[:, 0]
weights = torch.zeros_like(rankings, dtype=torch.float)
weights[torch.arange(rankings.shape[0]), winner] = 1.0
neighbours = topology.get_neighbours(winner)
weights[neighbours] = 0.1
return weights

View File

@ -1,31 +0,0 @@
"""ProtoTorch Pooling Modules."""
import torch
from prototorch.functions.pooling import (stratified_max_pooling,
stratified_min_pooling,
stratified_prod_pooling,
stratified_sum_pooling)
class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels):
return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_prod_pooling` function."""
def forward(self, values, labels):
return stratified_prod_pooling(values, labels)
class StratifiedMinPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_min_pooling` function."""
def forward(self, values, labels):
return stratified_min_pooling(values, labels)
class StratifiedMaxPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_max_pooling` function."""
def forward(self, values, labels):
return stratified_max_pooling(values, labels)

View File

@ -0,0 +1,4 @@
"""ProtoTorch Neural Network Module"""
from .activations import *
from .wrappers import *

View File

@ -0,0 +1,62 @@
"""ProtoTorch activations"""
import torch
ACTIVATIONS = dict()
def register_activation(fn):
"""Add the activation function to the registry."""
name = fn.__name__
ACTIVATIONS[name] = fn
return fn
@register_activation
def identity(x, beta=0.0):
"""Identity activation function.
Definition:
:math:`f(x) = x`
Keyword Arguments:
beta (`float`): Ignored.
"""
return x
@register_activation
def sigmoid_beta(x, beta=10.0):
r"""Sigmoid activation function with scaling.
Definition:
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
Keyword Arguments:
beta (`float`): Scaling parameter :math:`\beta`
"""
out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
return out
@register_activation
def swish_beta(x, beta=10.0):
r"""Swish activation function with scaling.
Definition:
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
Keyword Arguments:
beta (`float`): Scaling parameter :math:`\beta`
"""
out = x * sigmoid_beta(x, beta=beta)
return out
def get_activation(funcname):
"""Deserialize the activation function."""
if callable(funcname):
return funcname
if funcname in ACTIVATIONS:
return ACTIVATIONS.get(funcname)
raise NameError(f"Activation {funcname} was not found.")

View File

@ -1,4 +1,4 @@
"""ProtoTorch Wrappers."""
"""ProtoTorch wrappers."""
import torch