refactor(api)!: merge the new api changes into dev
BREAKING CHANGE: remove the following: `prototorch/functions/*`, `prototorch/components/*`, `prototorch/modules/*`
BREAKING CHANGE: move `initializers` from the `prototorch.components` namespace into the `prototorch.initializers` namespace
BREAKING CHANGE: `functions` and `modules` are moved into `core` and `nn`
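A rough migration sketch (the new paths come from the files added below; the exact pre-refactor import statements are assumptions inferred from the removed directories):

# Before this commit (assumed, based on the removed `prototorch/functions/*`
# and `prototorch/components/*` trees):
# from prototorch.functions.distances import sed
# from prototorch.components import LabeledComponents

# After this commit:
from prototorch.core.distances import sed
from prototorch.core.components import LabeledComponents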
prototorch/core/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""ProtoTorch core"""

from .competitions import *
from .components import *
from .distances import *
from .initializers import *
from .losses import *
from .pooling import *
from .similarities import *
from .transforms import *
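Because the package `__init__` above star-imports every submodule, the public names defined in the files below should also be importable directly from `prototorch.core`; a short sketch:

from prototorch.core import LabeledComponents, GLVQLoss, wtac  # re-exported via the star imports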
prototorch/core/competitions.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""ProtoTorch competitions"""

import torch


def wtac(distances: torch.Tensor, labels: torch.LongTensor):
    """Winner-Takes-All-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.min(distances, dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
    return winning_labels


def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1):
    """K-Nearest-Neighbors-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels


def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
    """Classification-By-Components Competition.

    Returns probability distributions over the classes.

    `detections` must be of shape [batch_size, num_components].
    `reasonings` must be of shape [num_components, num_classes, 2].

    """
    A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
    pk = A
    nk = (1 - A) * B
    numerator = (detections @ (pk - nk).T) + nk.sum(1)
    probs = numerator / (pk + nk).sum(1)
    return probs


class WTAC(torch.nn.Module):
    """Winner-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, distances, labels):
        return wtac(distances, labels)


class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, probs, labels):
        return wtac(-1.0 * probs, labels)


class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.

    """
    def __init__(self, k=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        return f"k: {self.k}"


class CBCC(torch.nn.Module):
    """Classification-By-Components Competition.

    Thin wrapper over the `cbcc` function.

    """
    def forward(self, detections, reasonings):
        return cbcc(detections, reasonings)
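A minimal usage sketch of the competition functions above (the shapes and data are illustrative only):

import torch
from prototorch.core.competitions import wtac, knnc, cbcc

distances = torch.rand(4, 3)            # batch of 4 samples, 3 prototypes
plabels = torch.LongTensor([0, 1, 2])   # one label per prototype
print(wtac(distances, plabels))         # winner label per sample
print(knnc(distances, plabels, k=3))    # majority label among the 3 nearest

detections = torch.rand(4, 3)           # [batch_size, num_components]
reasonings = torch.rand(3, 2, 2)        # [num_components, num_classes, 2]
print(cbcc(detections, reasonings))     # class probabilities per sample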
prototorch/core/components.py (new file, 370 lines)
@@ -0,0 +1,370 @@
"""ProtoTorch components"""

import inspect
from typing import Union

import torch
from torch.nn.parameter import Parameter

from ..utils import parse_distribution
from .initializers import (
    AbstractClassAwareCompInitializer,
    AbstractComponentsInitializer,
    AbstractLabelsInitializer,
    AbstractReasoningsInitializer,
    LabelsInitializer,
    PurePositiveReasoningsInitializer,
    RandomReasoningsInitializer,
)


def validate_initializer(initializer, instanceof):
    """Check if the initializer is valid."""
    if not isinstance(initializer, instanceof):
        emsg = f"`initializer` has to be an instance " \
            f"of some subtype of {instanceof}. " \
            f"You have provided: {initializer} instead. "
        helpmsg = ""
        if inspect.isclass(initializer):
            helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
                f"with the brackets instead of just {initializer.__name__}?"
        raise TypeError(emsg + helpmsg)
    return True


def gencat(ins, attr, init, *iargs, **ikwargs):
    """Generate new items and concatenate with existing items."""
    new_items = init.generate(*iargs, **ikwargs)
    if hasattr(ins, attr):
        items = torch.cat([getattr(ins, attr), new_items])
    else:
        items = new_items
    return items, new_items


def removeind(ins, attr, indices):
    """Remove items at specified indices."""
    mask = torch.ones(len(ins), dtype=torch.bool)
    mask[indices] = False
    items = getattr(ins, attr)[mask]
    return items, mask


def get_cikwargs(init, distribution):
    """Return appropriate key-word arguments for a component initializer."""
    if isinstance(init, AbstractClassAwareCompInitializer):
        cikwargs = dict(distribution=distribution)
    else:
        distribution = parse_distribution(distribution)
        num_components = sum(distribution.values())
        cikwargs = dict(num_components=num_components)
    return cikwargs


class AbstractComponents(torch.nn.Module):
    """Abstract class for all components modules."""
    @property
    def num_components(self):
        """Current number of components."""
        return len(self._components)

    @property
    def components(self):
        """Detached Tensor containing the components."""
        return self._components.detach().cpu()

    def _register_components(self, components):
        self.register_parameter("_components", Parameter(components))

    def extra_repr(self):
        return f"components: (shape: {tuple(self._components.shape)})"

    def __len__(self):
        return self.num_components


class Components(AbstractComponents):
    """A set of adaptable Tensors."""
    def __init__(self, num_components: int,
                 initializer: AbstractComponentsInitializer):
        super().__init__()
        self.add_components(num_components, initializer)

    def add_components(self, num_components: int,
                       initializer: AbstractComponentsInitializer):
        """Generate and add new components."""
        assert validate_initializer(initializer, AbstractComponentsInitializer)
        _components, new_components = gencat(self, "_components", initializer,
                                             num_components)
        self._register_components(_components)
        return new_components

    def remove_components(self, indices):
        """Remove components at specified indices."""
        _components, mask = removeind(self, "_components", indices)
        self._register_components(_components)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor."""
        return self._components


class AbstractLabels(torch.nn.Module):
    """Abstract class for all labels modules."""
    @property
    def labels(self):
        return self._labels.cpu()

    @property
    def num_labels(self):
        return len(self._labels)

    @property
    def unique_labels(self):
        return torch.unique(self._labels)

    @property
    def num_unique(self):
        return len(self.unique_labels)

    @property
    def distribution(self):
        unique, counts = torch.unique(self._labels,
                                      sorted=True,
                                      return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def extra_repr(self):
        r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
        if len(self.distribution) < 11:  # avoid lengthy representations
            d = self.distribution
            unique, counts = list(d.keys()), list(d.values())
            r += f", unique: {unique}, counts: {counts}"
        return r

    def __len__(self):
        return self.num_labels


class Labels(AbstractLabels):
    """A set of standalone labels."""
    def __init__(self,
                 distribution: Union[dict, list, tuple],
                 initializer: AbstractLabelsInitializer = LabelsInitializer()):
        super().__init__()
        self.add_labels(distribution, initializer)

    def add_labels(
            self,
            distribution: Union[dict, tuple, list],
            initializer: AbstractLabelsInitializer = LabelsInitializer()):
        """Generate and add new labels."""
        assert validate_initializer(initializer, AbstractLabelsInitializer)
        _labels, new_labels = gencat(self, "_labels", initializer,
                                     distribution)
        self._register_labels(_labels)
        return new_labels

    def remove_labels(self, indices):
        """Remove labels at specified indices."""
        _labels, mask = removeind(self, "_labels", indices)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the labels."""
        return self._labels


class LabeledComponents(AbstractComponents):
    """A set of adaptable components and corresponding unadaptable labels."""
    def __init__(
            self,
            distribution: Union[dict, list, tuple],
            components_initializer: AbstractComponentsInitializer,
            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        super().__init__()
        self.add_components(distribution, components_initializer,
                            labels_initializer)

    @property
    def distribution(self):
        unique, counts = torch.unique(self._labels,
                                      sorted=True,
                                      return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    @property
    def num_classes(self):
        return len(self.distribution.keys())

    @property
    def labels(self):
        """Tensor containing the component labels."""
        return self._labels.cpu()

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def add_components(
            self,
            distribution,
            components_initializer,
            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        """Generate and add new components and labels."""
        assert validate_initializer(components_initializer,
                                    AbstractComponentsInitializer)
        assert validate_initializer(labels_initializer,
                                    AbstractLabelsInitializer)
        cikwargs = get_cikwargs(components_initializer, distribution)
        _components, new_components = gencat(self, "_components",
                                             components_initializer,
                                             **cikwargs)
        _labels, new_labels = gencat(self, "_labels", labels_initializer,
                                     distribution)
        self._register_components(_components)
        self._register_labels(_labels)
        return new_components, new_labels

    def remove_components(self, indices):
        """Remove components and labels at specified indices."""
        _components, mask = removeind(self, "_components", indices)
        _labels, mask = removeind(self, "_labels", indices)
        self._register_components(_components)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor and labels."""
        return self._components, self._labels


class Reasonings(torch.nn.Module):
    """A set of standalone reasoning matrices.

    The `reasonings` tensor is of shape [num_components, num_classes, 2].

    """
    def __init__(
            self,
            distribution: Union[dict, list, tuple],
            initializer:
            AbstractReasoningsInitializer = RandomReasoningsInitializer()):
        super().__init__()

    @property
    def num_classes(self):
        return self._reasonings.shape[1]

    @property
    def reasonings(self):
        """Tensor containing the reasoning matrices."""
        return self._reasonings.detach().cpu()

    def _register_reasonings(self, reasonings):
        self.register_buffer("_reasonings", reasonings)

    def add_reasonings(
            self,
            distribution: Union[dict, list, tuple],
            initializer:
            AbstractReasoningsInitializer = RandomReasoningsInitializer()):
        """Generate and add new reasonings."""
        assert validate_initializer(initializer, AbstractReasoningsInitializer)
        _reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
                                             distribution)
        self._register_reasonings(_reasonings)
        return new_reasonings

    def remove_reasonings(self, indices):
        """Remove reasonings at specified indices."""
        _reasonings, mask = removeind(self, "_reasonings", indices)
        self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the reasonings."""
        return self._reasonings


class ReasoningComponents(AbstractComponents):
    r"""A set of components and corresponding adaptable reasoning matrices.

    Every component has its own reasoning matrix.

    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A component can reason in favour (positive) of a
    class, against (negative) a class, or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three-element probability distribution.

    """
    def __init__(
            self,
            distribution: Union[dict, list, tuple],
            components_initializer: AbstractComponentsInitializer,
            reasonings_initializer:
            AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()):
        super().__init__()
        self.add_components(distribution, components_initializer,
                            reasonings_initializer)

    @property
    def num_classes(self):
        return self._reasonings.shape[1]

    @property
    def reasonings(self):
        """Tensor containing the reasoning matrices."""
        return self._reasonings.detach().cpu()

    @property
    def reasoning_matrices(self):
        """Reasoning matrices for each class."""
        with torch.no_grad():
            A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
            pk = A
            nk = (1 - pk) * B
            ik = 1 - pk - nk
            matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
        return matrices.cpu()

    def _register_reasonings(self, reasonings):
        self.register_parameter("_reasonings", Parameter(reasonings))

    def add_components(self, distribution, components_initializer,
                       reasonings_initializer: AbstractReasoningsInitializer):
        """Generate and add new components and reasonings."""
        assert validate_initializer(components_initializer,
                                    AbstractComponentsInitializer)
        assert validate_initializer(reasonings_initializer,
                                    AbstractReasoningsInitializer)
        cikwargs = get_cikwargs(components_initializer, distribution)
        _components, new_components = gencat(self, "_components",
                                             components_initializer,
                                             **cikwargs)
        _reasonings, new_reasonings = gencat(self, "_reasonings",
                                             reasonings_initializer,
                                             distribution)
        self._register_components(_components)
        self._register_reasonings(_reasonings)
        return new_components, new_reasonings

    def remove_components(self, indices):
        """Remove components and reasonings at specified indices."""
        _components, mask = removeind(self, "_components", indices)
        _reasonings, mask = removeind(self, "_reasonings", indices)
        self._register_components(_components)
        self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the components and reasonings."""
        return self._components, self._reasonings
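A small, hedged example of how these containers compose with the initializers they import (class names are taken from this diff; the data is synthetic, and passing `[x, y]` relies on `parse_data_arg` accepting a data/target pair as documented in initializers.py below):

import torch
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import StratifiedMeanCompInitializer

x = torch.randn(100, 2)
y = torch.randint(0, 3, (100, ))
lc = LabeledComponents(
    distribution={0: 2, 1: 2, 2: 2},                              # two prototypes per class
    components_initializer=StratifiedMeanCompInitializer([x, y]),
)
protos, plabels = lc()                     # forward() returns components and labels
print(lc.num_components, lc.distribution)  # 6, {0: 2, 1: 2, 2: 2}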
prototorch/core/distances.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""ProtoTorch distances"""

import torch


def squared_euclidean_distance(x, y):
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.

    Compute :math:`{\| \bm x - \bm y \|}_2^2`

    **Alias:**
    ``prototorch.functions.distances.sed``
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    expanded_x = x.unsqueeze(dim=1)
    batchwise_difference = y - expanded_x
    differences_raised = torch.pow(batchwise_difference, 2)
    distances = torch.sum(differences_raised, axis=2)
    return distances


def euclidean_distance(x, y):
    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.

    Compute :math:`\sqrt{{\| \bm x - \bm y \|}_2^2}`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
    # batch diagonal. See:
    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
    # print(f"{diff.shape=}")  # (nx, ny, ndim)
    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
    # print(f"{distances.shape=}")  # (nx, ny)
    return distances


def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
    Also known as Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``

    :param p: p parameter of the lp norm
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances = torch.cdist(x, y, p=p)
    return distances


def omega_distance(x, y, omega):
    r"""Omega distance.

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
    r"""Localized Omega distance.

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
    batchwise_difference = expanded_y - projected_x
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)
    distances = distances.permute(1, 0)
    return distances


# Aliases
sed = squared_euclidean_distance
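A brief sketch of the distance helpers above (shapes are illustrative):

import torch
from prototorch.core.distances import squared_euclidean_distance, omega_distance

x = torch.randn(8, 4)                        # batch of samples
protos = torch.randn(5, 4)                   # 5 prototypes
d = squared_euclidean_distance(x, protos)    # shape (8, 5)

omega = torch.eye(4, 2)                      # projection matrix, GMLVQ-style
d_omega = omega_distance(x, protos, omega)   # squared distance in the projected space
print(d.shape, d_omega.shape)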
prototorch/core/initializers.py (new file, 494 lines)
@@ -0,0 +1,494 @@
"""ProtoTorch code initializers"""

import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import (
    Callable,
    Type,
    Union,
)

import torch

from ..utils import parse_data_arg, parse_distribution


# Components
class AbstractComponentsInitializer(ABC):
    """Abstract class for all components initializers."""
    ...


class LiteralCompInitializer(AbstractComponentsInitializer):
    """'Generate' the provided components.

    Use this to 'generate' pre-initialized components elsewhere.

    """
    def __init__(self, components):
        self.components = components

    def generate(self, num_components: int = 0):
        """Ignore `num_components` and simply return `self.components`."""
        if not isinstance(self.components, torch.Tensor):
            wmsg = f"Converting components to {torch.Tensor}..."
            warnings.warn(wmsg)
            self.components = torch.Tensor(self.components)
        return self.components


class ShapeAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all dimension-aware components initializers."""
    def __init__(self, shape: Union[Iterable, int]):
        if isinstance(shape, Iterable):
            self.component_shape = tuple(shape)
        else:
            self.component_shape = (shape, )

    @abstractmethod
    def generate(self, num_components: int):
        ...


class ZerosCompInitializer(ShapeAwareCompInitializer):
    """Generate zeros corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.zeros((num_components, ) + self.component_shape)
        return components


class OnesCompInitializer(ShapeAwareCompInitializer):
    """Generate ones corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.ones((num_components, ) + self.component_shape)
        return components


class FillValueCompInitializer(OnesCompInitializer):
    """Generate components with the provided `fill_value`."""
    def __init__(self, shape, fill_value: float = 1.0):
        super().__init__(shape)
        self.fill_value = fill_value

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = ones.fill_(self.fill_value)
        return components


class UniformCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a continuous uniform distribution."""
    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(shape)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * ones.uniform_(self.minimum, self.maximum)
        return components


class RandomNormalCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a standard normal distribution."""
    def __init__(self, shape, shift=0.0, scale=1.0):
        super().__init__(shape)
        self.shift = shift
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * (torch.randn_like(ones) + self.shift)
        return components


class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all data-aware components initializers.

    Components generated by data-aware components initializers inherit the shape
    of the provided data.

    `data` has to be a torch tensor.

    """
    def __init__(self,
                 data: torch.Tensor,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity()):
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, num_components: int):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data


class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
    """'Generate' the components from the provided data."""
    def generate(self, num_components: int = 0):
        """Ignore `num_components` and simply return transformed `self.data`."""
        components = self.generate_end_hook(self.data)
        return components


class SelectionCompInitializer(AbstractDataAwareCompInitializer):
    """Generate components by uniformly sampling from the provided data."""
    def generate(self, num_components: int):
        indices = torch.LongTensor(num_components).random_(0, len(self.data))
        samples = self.data[indices]
        components = self.generate_end_hook(samples)
        return components


class MeanCompInitializer(AbstractDataAwareCompInitializer):
    """Generate components by computing the mean of the provided data."""
    def generate(self, num_components: int):
        mean = self.data.mean(dim=0)
        repeat_dim = [num_components] + [1] * len(mean.shape)
        samples = mean.repeat(repeat_dim)
        components = self.generate_end_hook(samples)
        return components


class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all class-aware components initializers.

    Components generated by class-aware components initializers inherit the shape
    of the provided data.

    `data` could be a torch Dataset or DataLoader or a list/tuple of data and
    target tensors.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity()):
        self.data, self.targets = parse_data_arg(data)
        self.noise = noise
        self.transform = transform
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data
        del self.targets


class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
    """'Generate' components from provided data and requested distribution."""
    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return transformed `self.data`."""
        components = self.generate_end_hook(self.data)
        return components


class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
    """Abstract class for all stratified components initializers."""
    @property
    @abstractmethod
    def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
        ...

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        components = torch.tensor([])
        for k, v in distribution.items():
            stratified_data = self.data[self.targets == k]
            initializer = self.subinit_type(
                stratified_data,
                noise=self.noise,
                transform=self.transform,
            )
            samples = initializer.generate(num_components=v)
            components = torch.cat([components, samples])
        return components


class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
    """Generate components using stratified sampling from the provided data."""
    @property
    def subinit_type(self):
        return SelectionCompInitializer


class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
    """Generate components at stratified means of the provided data."""
    @property
    def subinit_type(self):
        return MeanCompInitializer


# Labels
class AbstractLabelsInitializer(ABC):
    """Abstract class for all labels initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...


class LiteralLabelsInitializer(AbstractLabelsInitializer):
    """'Generate' the provided labels.

    Use this to 'generate' pre-initialized labels elsewhere.

    """
    def __init__(self, labels):
        self.labels = labels

    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return `self.labels`.

        Convert to long tensor, if necessary.
        """
        labels = self.labels
        if not isinstance(labels, torch.LongTensor):
            wmsg = f"Converting labels to {torch.LongTensor}..."
            warnings.warn(wmsg)
            labels = torch.LongTensor(labels)
        return labels


class DataAwareLabelsInitializer(AbstractLabelsInitializer):
    """'Generate' the labels from a torch Dataset."""
    def __init__(self, data):
        self.data, self.targets = parse_data_arg(data)

    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return `self.targets`."""
        return self.targets


class LabelsInitializer(AbstractLabelsInitializer):
    """Generate labels from `distribution`."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        labels_list = []
        for k, v in distribution.items():
            labels_list.extend([k] * v)
        labels = torch.LongTensor(labels_list)
        return labels


class OneHotLabelsInitializer(LabelsInitializer):
    """Generate one-hot-encoded labels from `distribution`."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        num_classes = len(distribution.keys())
        # this breaks if class labels are not [0,...,nclasses]
        labels = torch.eye(num_classes)[super().generate(distribution)]
        return labels


# Reasonings
class AbstractReasoningsInitializer(ABC):
    """Abstract class for all reasonings initializers."""
    def __init__(self, components_first: bool = True):
        self.components_first = components_first

    def compute_shape(self, distribution):
        distribution = parse_distribution(distribution)
        num_components = sum(distribution.values())
        num_classes = len(distribution.keys())
        return (num_components, num_classes, 2)

    def generate_end_hook(self, reasonings):
        if not self.components_first:
            reasonings = reasonings.permute(2, 1, 0)
        return reasonings

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...
        return self.generate_end_hook(...)


class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
    """'Generate' the provided reasonings.

    Use this to 'generate' pre-initialized reasonings elsewhere.

    """
    def __init__(self, reasonings, **kwargs):
        super().__init__(**kwargs)
        self.reasonings = reasonings

    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return self.reasonings."""
        reasonings = self.reasonings
        if not isinstance(reasonings, torch.Tensor):
            wmsg = f"Converting reasonings to {torch.Tensor}..."
            warnings.warn(wmsg)
            reasonings = torch.Tensor(reasonings)
        reasonings = self.generate_end_hook(reasonings)
        return reasonings


class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
    """Reasonings are all initialized with zeros."""
    def generate(self, distribution: Union[dict, list, tuple]):
        shape = self.compute_shape(distribution)
        reasonings = torch.zeros(*shape)
        reasonings = self.generate_end_hook(reasonings)
        return reasonings


class OnesReasoningsInitializer(AbstractReasoningsInitializer):
    """Reasonings are all initialized with ones."""
    def generate(self, distribution: Union[dict, list, tuple]):
        shape = self.compute_shape(distribution)
        reasonings = torch.ones(*shape)
        reasonings = self.generate_end_hook(reasonings)
        return reasonings


class RandomReasoningsInitializer(AbstractReasoningsInitializer):
    """Reasonings are randomly initialized."""
    def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
        super().__init__(**kwargs)
        self.minimum = minimum
        self.maximum = maximum

    def generate(self, distribution: Union[dict, list, tuple]):
        shape = self.compute_shape(distribution)
        reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
        reasonings = self.generate_end_hook(reasonings)
        return reasonings


class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
    """Each component reasons positively for exactly one class."""
    def generate(self, distribution: Union[dict, list, tuple]):
        num_components, num_classes, _ = self.compute_shape(distribution)
        A = OneHotLabelsInitializer().generate(distribution)
        B = torch.zeros(num_components, num_classes)
        reasonings = torch.stack([A, B], dim=-1)
        reasonings = self.generate_end_hook(reasonings)
        return reasonings


# Transforms
class AbstractTransformInitializer(ABC):
    """Abstract class for all transform initializers."""
    ...


class AbstractLinearTransformInitializer(AbstractTransformInitializer):
    """Abstract class for all linear transform initializers."""
    def __init__(self, out_dim_first: bool = False):
        self.out_dim_first = out_dim_first

    def generate_end_hook(self, weights):
        if self.out_dim_first:
            weights = weights.permute(1, 0)
        return weights

    @abstractmethod
    def generate(self, in_dim: int, out_dim: int):
        ...
        return self.generate_end_hook(...)


class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with zeros."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.zeros(in_dim, out_dim)
        return self.generate_end_hook(weights)


class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with ones."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.ones(in_dim, out_dim)
        return self.generate_end_hook(weights)


class EyeTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with the largest possible identity matrix."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.zeros(in_dim, out_dim)
        I = torch.eye(min(in_dim, out_dim))
        weights[:I.shape[0], :I.shape[1]] = I
        return self.generate_end_hook(weights)


class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
    """Abstract class for all data-aware linear transform initializers."""
    def __init__(self,
                 data: torch.Tensor,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity()):
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, weights: torch.Tensor):
        drift = torch.rand_like(weights) * self.noise
        weights = self.transform(weights + drift)
        if self.out_dim_first:
            weights = weights.permute(1, 0)
        return weights


class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
    """Initialize a matrix with Eigenvectors from the data."""
    def generate(self, in_dim: int, out_dim: int):
        _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
        return self.generate_end_hook(weights)


# Aliases - Components
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
MCI = MeanCompInitializer
OCI = OnesCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
UCI = UniformCompInitializer
ZCI = ZerosCompInitializer

# Aliases - Labels
DLI = DataAwareLabelsInitializer
LI = LabelsInitializer
LLI = LiteralLabelsInitializer
OHLI = OneHotLabelsInitializer

# Aliases - Reasonings
LRI = LiteralReasoningsInitializer
ORI = OnesReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer

# Aliases - Transforms
Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer
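A small sketch of the initializer API above, using the aliases defined at the end of the file (shapes and distributions are illustrative; passing `[x, y]` again relies on `parse_data_arg` accepting a data/target pair):

import torch
from prototorch.core.initializers import RNCI, SMCI, OHLI

rnci = RNCI(shape=2)                                  # shape-aware: only needs num_components
protos = rnci.generate(num_components=3)              # tensor of shape (3, 2)

x = torch.randn(60, 2)
y = torch.randint(0, 2, (60, ))
smci = SMCI([x, y])                                   # class-aware: needs a distribution
class_means = smci.generate(distribution={0: 1, 1: 1})

onehot = OHLI().generate(distribution={0: 2, 1: 2})   # (4, 2) one-hot label matrix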
prototorch/core/losses.py (new file, 171 lines)
@@ -0,0 +1,171 @@
"""ProtoTorch losses"""

import torch

from ..nn.activations import get_activation


# Helpers
def _get_matcher(targets, labels):
    """Returns a boolean tensor."""
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim == 2:
        # if the labels are one-hot vectors
        num_classes = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    return matcher


def _get_dp_dm(distances, targets, plabels, with_indices=False):
    """Returns the d+ and d- values for a batch of distances."""
    matcher = _get_matcher(targets, plabels)
    not_matcher = torch.bitwise_not(matcher)

    inf = torch.full_like(distances, fill_value=float("inf"))
    d_matching = torch.where(matcher, distances, inf)
    d_unmatching = torch.where(not_matcher, distances, inf)
    dp = torch.min(d_matching, dim=-1, keepdim=True)
    dm = torch.min(d_unmatching, dim=-1, keepdim=True)
    if with_indices:
        return dp, dm
    return dp.values, dm.values


# GLVQ
def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels."""
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = (dp - dm) / (dp + dm)
    return mu


def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp
    mu[dp > dm] = -dm[dp > dm]
    return mu


def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    See Section 4 [Sato&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp - dm

    return mu


# Probabilistic
def _get_class_probabilities(probabilities, targets, prototype_labels):
    # Create Label Mapping
    uniques = prototype_labels.unique(sorted=True).tolist()
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}

    target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))

    whole = probabilities.sum(dim=1)
    correct = probabilities[torch.arange(len(probabilities)), target_indices]
    wrong = whole - correct

    return whole, correct, wrong


def nllr_loss(probabilities, targets, prototype_labels):
    """Compute the Negative Log-Likelihood Ratio loss."""
    _, correct, wrong = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / wrong
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


def rslvq_loss(probabilities, targets, prototype_labels):
    """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
    whole, correct, _ = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / whole
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


def margin_loss(y_pred, y_true, margin=0.3):
    """Compute the margin loss."""
    dp = torch.sum(y_true * y_pred, dim=-1)
    dm = torch.max(y_pred - y_true, dim=-1).values
    return torch.nn.functional.relu(dm - dp + margin)


class GLVQLoss(torch.nn.Module):
    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = torch.tensor(beta)

    def forward(self, outputs, targets):
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, prototype_labels=plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)


class MarginLoss(torch.nn.modules.loss._Loss):
    def __init__(self,
                 margin=0.3,
                 size_average=None,
                 reduce=None,
                 reduction="mean"):
        super().__init__(size_average, reduce, reduction)
        self.margin = margin

    def forward(self, y_pred, y_true):
        return margin_loss(y_pred, y_true, self.margin)


class NeuralGasEnergy(torch.nn.Module):
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    def forward(self, d):
        order = torch.argsort(d, dim=1)
        ranks = torch.argsort(order, dim=1)
        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)

        return cost, order

    def extra_repr(self):
        return f"lambda: {self.lm}"

    @staticmethod
    def _nghood_fn(rankings, lm):
        return torch.exp(-rankings / lm)


class GrowingNeuralGasEnergy(NeuralGasEnergy):
    def __init__(self, topology_layer, **kwargs):
        super().__init__(**kwargs)
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
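A minimal, hedged sketch of the GLVQ loss wrapper defined above (the distance matrix is random; the default "identity" squashing from `prototorch.nn.activations` is assumed to accept the `beta` keyword, as the forward pass passes it through):

import torch
from prototorch.core.losses import GLVQLoss, glvq_loss

distances = torch.rand(8, 4)                     # 8 samples, 4 prototypes
plabels = torch.LongTensor([0, 0, 1, 1])         # prototype labels
targets = torch.randint(0, 2, (8, ))             # ground-truth labels

mu = glvq_loss(distances, targets, plabels)      # per-sample classifier function
criterion = GLVQLoss()                           # identity squashing, zero margin
loss = criterion((distances, plabels), targets)  # summed over the batch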
prototorch/core/pooling.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""ProtoTorch pooling"""

from typing import Callable

import torch


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
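For illustration, a sketch of stratified pooling over per-component values (e.g. pooling distances or detection scores per class):

import torch
from prototorch.core.pooling import stratified_min_pooling

values = torch.rand(4, 6)                        # e.g. distances to 6 components
labels = torch.LongTensor([0, 0, 1, 1, 2, 2])    # class label of each component/column
pooled = stratified_min_pooling(values, labels)  # shape (4, 3): per-class minimum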
prototorch/core/similarities.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""ProtoTorch similarities."""

import torch

from .distances import euclidean_distance


def gaussian(x, variance=1.0):
    return torch.exp(-(x * x) / (2 * variance))


def euclidean_similarity(x, y, variance=1.0):
    distances = euclidean_distance(x, y)
    similarities = gaussian(distances, variance)
    return similarities


def cosine_similarity(x, y):
    """Compute the cosine similarity between :math:`x` and :math:`y`.

    Expected dimension of x is 2.
    Expected dimension of y is 2.
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    norm_x = x.pow(2).sum(1).sqrt()
    norm_y = y.pow(2).sum(1).sqrt()
    norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
    epsilon = torch.finfo(norm_mat.dtype).eps
    norm_mat.clamp_(min=epsilon)
    similarities = (x @ y.T) / norm_mat
    return similarities
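A quick sketch of the similarity helpers above (shapes are illustrative):

import torch
from prototorch.core.similarities import euclidean_similarity, cosine_similarity

x = torch.randn(4, 3)
y = torch.randn(5, 3)
print(euclidean_similarity(x, y, variance=2.0).shape)  # (4, 5), Gaussian of the distance
print(cosine_similarity(x, y).shape)                   # (4, 5)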
prototorch/core/transforms.py (new file, 43 lines)
@@ -0,0 +1,43 @@
"""ProtoTorch transforms"""

import torch
from torch.nn.parameter import Parameter

from .initializers import (
    AbstractLinearTransformInitializer,
    EyeTransformInitializer,
)


class LinearTransform(torch.nn.Module):
    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            initializer:
            AbstractLinearTransformInitializer = EyeTransformInitializer()):
        super().__init__()
        self.set_weights(in_dim, out_dim, initializer)

    @property
    def weights(self):
        return self._weights.detach().cpu()

    def _register_weights(self, weights):
        self.register_parameter("_weights", Parameter(weights))

    def set_weights(
            self,
            in_dim: int,
            out_dim: int,
            initializer:
            AbstractLinearTransformInitializer = EyeTransformInitializer()):
        weights = initializer.generate(in_dim, out_dim)
        self._register_weights(weights)

    def forward(self, x):
        # use the registered (in_dim, out_dim) parameter so gradients flow
        return x @ self._weights


# Aliases
Omega = LinearTransform
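And a short sketch of the linear transform (aliased as `Omega`) defined above, using the rectangular-identity initializer it defaults to:

import torch
from prototorch.core.transforms import LinearTransform

omega = LinearTransform(in_dim=4, out_dim=2)  # EyeTransformInitializer by default
x = torch.randn(8, 4)
projected = omega(x)                          # shape (8, 2)
print(omega.weights.shape)                    # detached copy of the (4, 2) parameter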