refactor(api)!: merge the new api changes into dev
BREAKING CHANGE: remove the following `prototorch/functions/*` `prototorch/components/*` `prototorch/modules/*` BREAKING CHANGE: move `initializers` into the `prototorch.initializers` namespace from the `prototorch.components` namespace BREAKING CHANGE: `functions` and `modules` and moved into `core` and `nn`
This commit is contained in:
commit
5dc66494ea
@ -3,8 +3,8 @@ current_version = 0.5.1
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
serialize =
|
||||
{major}.{minor}.{patch}
|
||||
serialize = {major}.{minor}.{patch}
|
||||
message = bump: {current_version} → {new_version}
|
||||
|
||||
[bumpversion:file:setup.py]
|
||||
|
||||
|
16
.gitignore
vendored
16
.gitignore
vendored
@ -129,14 +129,6 @@ dmypy.json
|
||||
|
||||
# End of https://www.gitignore.io/api/python
|
||||
|
||||
# ProtoFlow
|
||||
core
|
||||
checkpoint
|
||||
logs/
|
||||
saved_weights/
|
||||
scratch*
|
||||
|
||||
|
||||
# Created by https://www.gitignore.io/api/visualstudiocode
|
||||
# Edit at https://www.gitignore.io/?templates=visualstudiocode
|
||||
|
||||
@ -154,5 +146,13 @@ scratch*
|
||||
# End of https://www.gitignore.io/api/visualstudiocode
|
||||
.vscode/
|
||||
|
||||
# Vim
|
||||
*~
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Artifacts created by ProtoTorch
|
||||
reports
|
||||
artifacts
|
||||
examples/_*.py
|
||||
examples/_*.ipynb
|
||||
|
@ -23,19 +23,19 @@ repos:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: 'v0.902'
|
||||
rev: v0.902
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: prototorch
|
||||
additional_dependencies: [types-pkg_resources]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
||||
rev: v0.31.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
|
||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||
rev: v1.9.0 # Use the ref you want to point at
|
||||
rev: v1.9.0
|
||||
hooks:
|
||||
- id: python-use-type-annotations
|
||||
- id: python-no-log-warn
|
||||
@ -47,8 +47,8 @@ repos:
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/jorisroovers/gitlint
|
||||
rev: "v0.15.1"
|
||||
- repo: https://github.com/si-cim/gitlint
|
||||
rev: v0.15.2-unofficial
|
||||
hooks:
|
||||
- id: gitlint
|
||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||
|
7
.remarkrc
Normal file
7
.remarkrc
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"plugins": [
|
||||
"remark-preset-lint-recommended",
|
||||
["remark-lint-list-item-indent", false],
|
||||
["no-emphasis-as-header", false]
|
||||
]
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python: 3.8
|
||||
python: 3.9
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
|
10
README.md
10
README.md
@ -51,14 +51,20 @@ that link not work try <https://prototorch.readthedocs.io/en/latest/>.
|
||||
## Contribution
|
||||
|
||||
This repository contains definition for [git hooks](https://githooks.com).
|
||||
[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
|
||||
Please install the hooks by running
|
||||
[Pre-commit](https://pre-commit.com) is automatically installed as development
|
||||
dependency with prototorch or you can install it manually with `pip install
|
||||
pre-commit`.
|
||||
|
||||
Please install the hooks by running:
|
||||
```bash
|
||||
pre-commit install
|
||||
pre-commit install --hook-type commit-msg
|
||||
```
|
||||
before creating the first commit.
|
||||
|
||||
The commit will fail if the commit message does not follow the specification
|
||||
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
|
||||
|
||||
## Bibtex
|
||||
|
||||
If you would like to cite the package, please use this:
|
||||
|
96
examples/cbc_iris.py
Normal file
96
examples/cbc_iris.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""ProtoTorch CBC example using 2D Iris data."""
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
class CBC(torch.nn.Module):
|
||||
def __init__(self, data, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.components_layer = pt.components.ReasoningComponents(
|
||||
distribution=[2, 1, 2],
|
||||
components_initializer=pt.initializers.SSCI(data, noise=0.1),
|
||||
reasonings_initializer=pt.initializers.PPRI(components_first=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
components, reasonings = self.components_layer()
|
||||
sims = pt.similarities.euclidean_similarity(x, components)
|
||||
probs = pt.competitions.cbcc(sims, reasonings)
|
||||
return probs
|
||||
|
||||
|
||||
class VisCBC2D():
|
||||
def __init__(self, model, data):
|
||||
self.model = model
|
||||
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
|
||||
self.title = "Components Visualization"
|
||||
self.fig = plt.figure(self.title)
|
||||
self.border = 0.1
|
||||
self.resolution = 100
|
||||
self.cmap = "viridis"
|
||||
|
||||
def on_epoch_end(self):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
_components = self.model.components_layer._components.detach()
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
ax.axis("off")
|
||||
ax.scatter(
|
||||
x_train[:, 0],
|
||||
x_train[:, 1],
|
||||
c=y_train,
|
||||
cmap=self.cmap,
|
||||
edgecolor="k",
|
||||
marker="o",
|
||||
s=30,
|
||||
)
|
||||
ax.scatter(
|
||||
_components[:, 0],
|
||||
_components[:, 1],
|
||||
c="w",
|
||||
cmap=self.cmap,
|
||||
edgecolor="k",
|
||||
marker="D",
|
||||
s=50,
|
||||
)
|
||||
x = torch.vstack((x_train, _components))
|
||||
mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
|
||||
with torch.no_grad():
|
||||
y_pred = self.model(
|
||||
torch.Tensor(mesh_input).type_as(_components)).argmax(1)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
plt.pause(0.2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||
|
||||
model = CBC(train_ds)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||||
criterion = pt.losses.MarginLoss(margin=0.1)
|
||||
vis = VisCBC2D(model, train_ds)
|
||||
|
||||
for epoch in range(200):
|
||||
correct = 0.0
|
||||
for x, y in train_loader:
|
||||
y_oh = torch.eye(3)[y]
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y_oh).mean(0)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||
|
||||
acc = 100 * correct / len(train_ds)
|
||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_epoch_end()
|
@ -1,39 +1,35 @@
|
||||
"""This example script shows the usage of the new components architecture.
|
||||
|
||||
Serialization/deserialization also works as expected.
|
||||
|
||||
"""
|
||||
|
||||
# DATASET
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
import prototorch as pt
|
||||
|
||||
x_train = torch.Tensor(x_train)
|
||||
y_train = torch.Tensor(y_train)
|
||||
num_classes = len(torch.unique(y_train))
|
||||
ds = pt.datasets.Iris()
|
||||
|
||||
# CREATE NEW COMPONENTS
|
||||
from prototorch.components import *
|
||||
from prototorch.components.initializers import *
|
||||
|
||||
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||
unsupervised = pt.components.Components(
|
||||
6,
|
||||
initializer=pt.initializers.ZCI(2),
|
||||
)
|
||||
print(unsupervised())
|
||||
|
||||
prototypes = LabeledComponents(
|
||||
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||
prototypes = pt.components.LabeledComponents(
|
||||
(3, 2),
|
||||
components_initializer=pt.initializers.SSCI(ds),
|
||||
)
|
||||
print(prototypes())
|
||||
|
||||
components = ReasoningComponents(
|
||||
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(components())
|
||||
components = pt.components.ReasoningComponents(
|
||||
(3, 2),
|
||||
components_initializer=pt.initializers.SSCI(ds),
|
||||
reasonings_initializer=pt.initializers.PPRI(),
|
||||
)
|
||||
print(prototypes())
|
||||
|
||||
# TEST SERIALIZATION
|
||||
# Test Serialization
|
||||
import io
|
||||
|
||||
save = io.BytesIO()
|
||||
@ -41,25 +37,20 @@ torch.save(unsupervised, save)
|
||||
save.seek(0)
|
||||
serialized_unsupervised = torch.load(save)
|
||||
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components)
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(prototypes, save)
|
||||
save.seek(0)
|
||||
serialized_prototypes = torch.load(save)
|
||||
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||
component_labels), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components)
|
||||
assert torch.all(prototypes.labels == serialized_prototypes.labels)
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(components, save)
|
||||
save.seek(0)
|
||||
serialized_components = torch.load(save)
|
||||
|
||||
assert torch.all(components.components == serialized_components.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.components == serialized_components.components)
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings)
|
||||
|
@ -1,21 +1,41 @@
|
||||
"""ProtoTorch package."""
|
||||
"""ProtoTorch package"""
|
||||
|
||||
import pkgutil
|
||||
from typing import List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from . import components, datasets, functions, modules, utils
|
||||
from .datasets import *
|
||||
from . import (
|
||||
datasets,
|
||||
nn,
|
||||
utils,
|
||||
)
|
||||
from .core import (
|
||||
competitions,
|
||||
components,
|
||||
distances,
|
||||
initializers,
|
||||
losses,
|
||||
pooling,
|
||||
similarities,
|
||||
transforms,
|
||||
)
|
||||
|
||||
# Core Setup
|
||||
__version__ = "0.5.1"
|
||||
|
||||
__all_core__ = [
|
||||
"datasets",
|
||||
"functions",
|
||||
"modules",
|
||||
"competitions",
|
||||
"components",
|
||||
"core",
|
||||
"datasets",
|
||||
"distances",
|
||||
"initializers",
|
||||
"losses",
|
||||
"nn",
|
||||
"pooling",
|
||||
"similarities",
|
||||
"transforms",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
@ -1,2 +0,0 @@
|
||||
from prototorch.components.components import *
|
||||
from prototorch.components.initializers import *
|
@ -1,270 +0,0 @@
|
||||
"""ProtoTorch components modules."""
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||
ComponentsInitializer,
|
||||
EqualLabelsInitializer,
|
||||
UnequalLabelsInitializer,
|
||||
ZeroReasoningsInitializer)
|
||||
|
||||
from .initializers import parse_data_arg
|
||||
|
||||
|
||||
def get_labels_object(distribution):
|
||||
if isinstance(distribution, dict):
|
||||
if "num_classes" in distribution.keys():
|
||||
labels = EqualLabelsInitializer(
|
||||
distribution["num_classes"],
|
||||
distribution["prototypes_per_class"])
|
||||
else:
|
||||
clabels = list(distribution.keys())
|
||||
dist = list(distribution.values())
|
||||
labels = UnequalLabelsInitializer(dist, clabels)
|
||||
elif isinstance(distribution, tuple):
|
||||
num_classes, prototypes_per_class = distribution
|
||||
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||
elif isinstance(distribution, list):
|
||||
labels = UnequalLabelsInitializer(distribution)
|
||||
else:
|
||||
msg = f"`distribution` not understood." \
|
||||
f"You have provided: {distribution=}."
|
||||
raise ValueError(msg)
|
||||
return labels
|
||||
|
||||
|
||||
def _precheck_initializer(initializer):
|
||||
if not isinstance(initializer, ComponentsInitializer):
|
||||
emsg = f"`initializer` has to be some subtype of " \
|
||||
f"{ComponentsInitializer}. " \
|
||||
f"You have provided: {initializer=} instead."
|
||||
raise TypeError(emsg)
|
||||
|
||||
|
||||
class LinearMapping(torch.nn.Module):
|
||||
"""LinearMapping is a learnable Mapping Matrix."""
|
||||
def __init__(self,
|
||||
mapping_shape=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_linearmapping=None):
|
||||
super().__init__()
|
||||
|
||||
# Ignore all initialization settings if initialized_components is given.
|
||||
if initialized_linearmapping is not None:
|
||||
self._register_mapping(initialized_linearmapping)
|
||||
if num_components is not None or initializer is not None:
|
||||
wmsg = "Arguments ignored while initializing Components"
|
||||
warnings.warn(wmsg)
|
||||
else:
|
||||
self._initialize_mapping(mapping_shape, initializer)
|
||||
|
||||
@property
|
||||
def mapping_shape(self):
|
||||
return self._omega.shape
|
||||
|
||||
def _register_mapping(self, components):
|
||||
self.register_parameter("_omega", Parameter(components))
|
||||
|
||||
def _initialize_mapping(self, mapping_shape, initializer):
|
||||
_precheck_initializer(initializer)
|
||||
_mapping = initializer.generate(mapping_shape)
|
||||
self._register_mapping(_mapping)
|
||||
|
||||
@property
|
||||
def mapping(self):
|
||||
"""Tensor containing the component tensors."""
|
||||
return self._omega.detach()
|
||||
|
||||
def forward(self):
|
||||
return self._omega
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
|
||||
"""Components is a set of learnable Tensors."""
|
||||
def __init__(self,
|
||||
num_components=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
super().__init__()
|
||||
|
||||
# Ignore all initialization settings if initialized_components is given.
|
||||
if initialized_components is not None:
|
||||
self._register_components(initialized_components)
|
||||
if num_components is not None or initializer is not None:
|
||||
wmsg = "Arguments ignored while initializing Components"
|
||||
warnings.warn(wmsg)
|
||||
else:
|
||||
self._initialize_components(num_components, initializer)
|
||||
|
||||
@property
|
||||
def num_components(self):
|
||||
return len(self._components)
|
||||
|
||||
def _register_components(self, components):
|
||||
self.register_parameter("_components", Parameter(components))
|
||||
|
||||
def _initialize_components(self, num_components, initializer):
|
||||
_precheck_initializer(initializer)
|
||||
_components = initializer.generate(num_components)
|
||||
self._register_components(_components)
|
||||
|
||||
def add_components(self,
|
||||
num=1,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
if initialized_components is not None:
|
||||
_components = torch.cat([self._components, initialized_components])
|
||||
else:
|
||||
_precheck_initializer(initializer)
|
||||
_new = initializer.generate(num)
|
||||
_components = torch.cat([self._components, _new])
|
||||
self._register_components(_components)
|
||||
|
||||
def remove_components(self, indices=None):
|
||||
mask = torch.ones(self.num_components, dtype=torch.bool)
|
||||
mask[indices] = False
|
||||
_components = self._components[mask]
|
||||
self._register_components(_components)
|
||||
return mask
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
"""Tensor containing the component tensors."""
|
||||
return self._components.detach()
|
||||
|
||||
def forward(self):
|
||||
return self._components
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(components): (shape: {tuple(self._components.shape)})"
|
||||
|
||||
|
||||
class LabeledComponents(Components):
|
||||
"""LabeledComponents generate a set of components and a set of labels.
|
||||
|
||||
Every Component has a label assigned.
|
||||
"""
|
||||
def __init__(self,
|
||||
distribution=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
if initialized_components is not None:
|
||||
components, component_labels = parse_data_arg(
|
||||
initialized_components)
|
||||
super().__init__(initialized_components=components)
|
||||
self._register_labels(component_labels)
|
||||
else:
|
||||
labels = get_labels_object(distribution)
|
||||
self.initial_distribution = labels.distribution
|
||||
_labels = labels.generate()
|
||||
super().__init__(len(_labels), initializer=initializer)
|
||||
self._register_labels(_labels)
|
||||
|
||||
def _register_labels(self, labels):
|
||||
self.register_buffer("_labels", labels)
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
clabels, counts = torch.unique(self._labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
return dict(zip(clabels.tolist(), counts.tolist()))
|
||||
|
||||
def _initialize_components(self, num_components, initializer):
|
||||
if isinstance(initializer, ClassAwareInitializer):
|
||||
_precheck_initializer(initializer)
|
||||
_components = initializer.generate(num_components,
|
||||
self.initial_distribution)
|
||||
self._register_components(_components)
|
||||
else:
|
||||
super()._initialize_components(num_components, initializer)
|
||||
|
||||
def add_components(self, distribution, initializer):
|
||||
_precheck_initializer(initializer)
|
||||
|
||||
# Labels
|
||||
labels = get_labels_object(distribution)
|
||||
new_labels = labels.generate()
|
||||
_labels = torch.cat([self._labels, new_labels])
|
||||
self._register_labels(_labels)
|
||||
|
||||
# Components
|
||||
if isinstance(initializer, ClassAwareInitializer):
|
||||
_new = initializer.generate(len(new_labels), distribution)
|
||||
else:
|
||||
_new = initializer.generate(len(new_labels))
|
||||
_components = torch.cat([self._components, _new])
|
||||
self._register_components(_components)
|
||||
|
||||
def remove_components(self, indices=None):
|
||||
# Components
|
||||
mask = super().remove_components(indices)
|
||||
|
||||
# Labels
|
||||
_labels = self._labels[mask]
|
||||
self._register_labels(_labels)
|
||||
|
||||
@property
|
||||
def component_labels(self):
|
||||
"""Tensor containing the component tensors."""
|
||||
return self._labels.detach()
|
||||
|
||||
def forward(self):
|
||||
return super().forward(), self._labels
|
||||
|
||||
|
||||
class ReasoningComponents(Components):
|
||||
r"""ReasoningComponents generate a set of components and a set of reasoning matrices.
|
||||
|
||||
Every Component has a reasoning matrix assigned.
|
||||
|
||||
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
|
||||
first element is called positive reasoning :math:`p`, the second negative
|
||||
reasoning :math:`n`. A components can reason in favour (positive) of a
|
||||
class, against (negative) a class or not at all (neutral).
|
||||
|
||||
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
|
||||
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
|
||||
three element probability distribution.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
reasonings=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
if initialized_components is not None:
|
||||
components, reasonings = initialized_components
|
||||
|
||||
super().__init__(initialized_components=components)
|
||||
self.register_parameter("_reasonings", reasonings)
|
||||
else:
|
||||
self._initialize_reasonings(reasonings)
|
||||
super().__init__(len(self._reasonings), initializer=initializer)
|
||||
|
||||
def _initialize_reasonings(self, reasonings):
|
||||
if isinstance(reasonings, tuple):
|
||||
num_classes, num_components = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
||||
|
||||
_reasonings = reasonings.generate()
|
||||
self.register_parameter("_reasonings", _reasonings)
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""Returns Reasoning Matrix.
|
||||
|
||||
Dimension NxCx2
|
||||
|
||||
"""
|
||||
return self._reasonings.detach()
|
||||
|
||||
def forward(self):
|
||||
return super().forward(), self._reasonings
|
@ -1,233 +0,0 @@
|
||||
"""ProtoTroch Initializers."""
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
def parse_data_arg(data_arg):
|
||||
if isinstance(data_arg, Dataset):
|
||||
data_arg = DataLoader(data_arg, batch_size=len(data_arg))
|
||||
|
||||
if isinstance(data_arg, DataLoader):
|
||||
data = torch.tensor([])
|
||||
targets = torch.tensor([])
|
||||
for x, y in data_arg:
|
||||
data = torch.cat([data, x])
|
||||
targets = torch.cat([targets, y])
|
||||
else:
|
||||
data, targets = data_arg
|
||||
if not isinstance(data, torch.Tensor):
|
||||
wmsg = f"Converting data to {torch.Tensor}."
|
||||
warnings.warn(wmsg)
|
||||
data = torch.Tensor(data)
|
||||
if not isinstance(targets, torch.Tensor):
|
||||
wmsg = f"Converting targets to {torch.Tensor}."
|
||||
warnings.warn(wmsg)
|
||||
targets = torch.Tensor(targets)
|
||||
return data, targets
|
||||
|
||||
|
||||
def get_subinitializers(data, targets, clabels, subinit_type):
|
||||
initializers = dict()
|
||||
for clabel in clabels:
|
||||
class_data = data[targets == clabel]
|
||||
class_initializer = subinit_type(class_data)
|
||||
initializers[clabel] = (class_initializer)
|
||||
return initializers
|
||||
|
||||
|
||||
# Components
|
||||
class ComponentsInitializer(object):
|
||||
def generate(self, number_of_components):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class DimensionAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
if isinstance(dims, Iterable):
|
||||
self.components_dims = tuple(dims)
|
||||
else:
|
||||
self.components_dims = (dims, )
|
||||
|
||||
|
||||
class OnesInitializer(DimensionAwareInitializer):
|
||||
def __init__(self, dims, scale=1.0):
|
||||
super().__init__(dims)
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.ones(gen_dims) * self.scale
|
||||
|
||||
|
||||
class ZerosInitializer(DimensionAwareInitializer):
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.zeros(gen_dims)
|
||||
|
||||
|
||||
class UniformInitializer(DimensionAwareInitializer):
|
||||
def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
|
||||
super().__init__(dims)
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.ones(gen_dims).uniform_(self.minimum,
|
||||
self.maximum) * self.scale
|
||||
|
||||
|
||||
class DataAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, data, transform=torch.nn.Identity()):
|
||||
super().__init__()
|
||||
self.data = data
|
||||
self.transform = transform
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class SelectionInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||
return self.transform(self.data[indices])
|
||||
|
||||
|
||||
class MeanInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
mean = torch.mean(self.data, dim=0)
|
||||
repeat_dim = [length] + [1] * len(mean.shape)
|
||||
return self.transform(mean.repeat(repeat_dim))
|
||||
|
||||
|
||||
class ClassAwareInitializer(DataAwareInitializer):
|
||||
def __init__(self, data, transform=torch.nn.Identity()):
|
||||
data, targets = parse_data_arg(data)
|
||||
super().__init__(data, transform)
|
||||
self.targets = targets
|
||||
self.clabels = torch.unique(self.targets).int().tolist()
|
||||
self.num_classes = len(self.clabels)
|
||||
|
||||
def _get_samples_from_initializer(self, length, dist):
|
||||
if not dist:
|
||||
per_class = length // self.num_classes
|
||||
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
||||
if isinstance(dist, list):
|
||||
dist = dict(zip(self.clabels, dist))
|
||||
samples = [self.initializers[k].generate(n) for k, n in dist.items()]
|
||||
out = torch.vstack(samples)
|
||||
with torch.no_grad():
|
||||
out = self.transform(out)
|
||||
return out
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
del self.targets
|
||||
|
||||
|
||||
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||
def __init__(self, data, **kwargs):
|
||||
super().__init__(data, **kwargs)
|
||||
self.initializers = get_subinitializers(self.data, self.targets,
|
||||
self.clabels, MeanInitializer)
|
||||
|
||||
def generate(self, length, dist):
|
||||
samples = self._get_samples_from_initializer(length, dist)
|
||||
return samples
|
||||
|
||||
|
||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||
def __init__(self, data, noise=None, **kwargs):
|
||||
super().__init__(data, **kwargs)
|
||||
self.noise = noise
|
||||
self.initializers = get_subinitializers(self.data, self.targets,
|
||||
self.clabels,
|
||||
SelectionInitializer)
|
||||
|
||||
def add_noise_v1(self, x):
|
||||
return x + self.noise
|
||||
|
||||
def add_noise_v2(self, x):
|
||||
"""Shifts some dimensions of the data randomly."""
|
||||
n1 = torch.rand_like(x)
|
||||
n2 = torch.rand_like(x)
|
||||
mask = torch.bernoulli(n1) - torch.bernoulli(n2)
|
||||
return x + (self.noise * mask)
|
||||
|
||||
def generate(self, length, dist):
|
||||
samples = self._get_samples_from_initializer(length, dist)
|
||||
if self.noise is not None:
|
||||
samples = self.add_noise_v1(samples)
|
||||
return samples
|
||||
|
||||
|
||||
# Omega matrix
|
||||
class PcaInitializer(DataAwareInitializer):
|
||||
def generate(self, shape):
|
||||
(input_dim, latent_dim) = shape
|
||||
(_, eigVal, eigVec) = torch.pca_lowrank(self.data, q=latent_dim)
|
||||
return eigVec
|
||||
|
||||
|
||||
# Labels
|
||||
class LabelsInitializer:
|
||||
def generate(self):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class UnequalLabelsInitializer(LabelsInitializer):
|
||||
def __init__(self, dist, clabels=None):
|
||||
self.dist = dist
|
||||
self.clabels = clabels or range(len(self.dist))
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
return self.dist
|
||||
|
||||
def generate(self):
|
||||
targets = list(
|
||||
chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
|
||||
return torch.LongTensor(targets)
|
||||
|
||||
|
||||
class EqualLabelsInitializer(LabelsInitializer):
|
||||
def __init__(self, classes, per_class):
|
||||
self.classes = classes
|
||||
self.per_class = per_class
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
return self.classes * [self.per_class]
|
||||
|
||||
def generate(self):
|
||||
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||
|
||||
|
||||
# Reasonings
|
||||
class ReasoningsInitializer:
|
||||
def generate(self, length):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||
def __init__(self, classes, length):
|
||||
self.classes = classes
|
||||
self.length = length
|
||||
|
||||
def generate(self):
|
||||
return torch.zeros((self.length, self.classes, 2))
|
||||
|
||||
|
||||
# Aliases
|
||||
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||
SMI = StratifiedMeanInitializer
|
||||
Random = RandomInitializer = UniformInitializer
|
||||
Zeros = ZerosInitializer
|
||||
Ones = OnesInitializer
|
||||
PCA = PcaInitializer
|
10
prototorch/core/__init__.py
Normal file
10
prototorch/core/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""ProtoTorch core"""
|
||||
|
||||
from .competitions import *
|
||||
from .components import *
|
||||
from .distances import *
|
||||
from .initializers import *
|
||||
from .losses import *
|
||||
from .pooling import *
|
||||
from .similarities import *
|
||||
from .transforms import *
|
89
prototorch/core/competitions.py
Normal file
89
prototorch/core/competitions.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""ProtoTorch competitions"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def wtac(distances: torch.Tensor, labels: torch.LongTensor):
|
||||
"""Winner-Takes-All-Competition.
|
||||
|
||||
Returns the labels corresponding to the winners.
|
||||
|
||||
"""
|
||||
winning_indices = torch.min(distances, dim=1).indices
|
||||
winning_labels = labels[winning_indices].squeeze()
|
||||
return winning_labels
|
||||
|
||||
|
||||
def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1):
|
||||
"""K-Nearest-Neighbors-Competition.
|
||||
|
||||
Returns the labels corresponding to the winners.
|
||||
|
||||
"""
|
||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
||||
return winning_labels
|
||||
|
||||
|
||||
def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
||||
"""Classification-By-Components Competition.
|
||||
|
||||
Returns probability distributions over the classes.
|
||||
|
||||
`detections` must be of shape [batch_size, num_components].
|
||||
`reasonings` must be of shape [num_components, num_classes, 2].
|
||||
|
||||
"""
|
||||
A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - A) * B
|
||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||
probs = numerator / (pk + nk).sum(1)
|
||||
return probs
|
||||
|
||||
|
||||
class WTAC(torch.nn.Module):
|
||||
"""Winner-Takes-All-Competition Layer.
|
||||
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, distances, labels):
|
||||
return wtac(distances, labels)
|
||||
|
||||
|
||||
class LTAC(torch.nn.Module):
|
||||
"""Loser-Takes-All-Competition Layer.
|
||||
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, probs, labels):
|
||||
return wtac(-1.0 * probs, labels)
|
||||
|
||||
|
||||
class KNNC(torch.nn.Module):
|
||||
"""K-Nearest-Neighbors-Competition.
|
||||
|
||||
Thin wrapper over the `knnc` function.
|
||||
|
||||
"""
|
||||
def __init__(self, k=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.k = k
|
||||
|
||||
def forward(self, distances, labels):
|
||||
return knnc(distances, labels, k=self.k)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"k: {self.k}"
|
||||
|
||||
|
||||
class CBCC(torch.nn.Module):
|
||||
"""Classification-By-Components Competition.
|
||||
|
||||
Thin wrapper over the `cbcc` function.
|
||||
|
||||
"""
|
||||
def forward(self, detections, reasonings):
|
||||
return cbcc(detections, reasonings)
|
370
prototorch/core/components.py
Normal file
370
prototorch/core/components.py
Normal file
@ -0,0 +1,370 @@
|
||||
"""ProtoTorch components"""
|
||||
|
||||
import inspect
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..utils import parse_distribution
|
||||
from .initializers import (
|
||||
AbstractClassAwareCompInitializer,
|
||||
AbstractComponentsInitializer,
|
||||
AbstractLabelsInitializer,
|
||||
AbstractReasoningsInitializer,
|
||||
LabelsInitializer,
|
||||
PurePositiveReasoningsInitializer,
|
||||
RandomReasoningsInitializer,
|
||||
)
|
||||
|
||||
|
||||
def validate_initializer(initializer, instanceof):
|
||||
"""Check if the initializer is valid."""
|
||||
if not isinstance(initializer, instanceof):
|
||||
emsg = f"`initializer` has to be an instance " \
|
||||
f"of some subtype of {instanceof}. " \
|
||||
f"You have provided: {initializer} instead. "
|
||||
helpmsg = ""
|
||||
if inspect.isclass(initializer):
|
||||
helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
|
||||
f"with the brackets instead of just {initializer.__name__}?"
|
||||
raise TypeError(emsg + helpmsg)
|
||||
return True
|
||||
|
||||
|
||||
def gencat(ins, attr, init, *iargs, **ikwargs):
|
||||
"""Generate new items and concatenate with existing items."""
|
||||
new_items = init.generate(*iargs, **ikwargs)
|
||||
if hasattr(ins, attr):
|
||||
items = torch.cat([getattr(ins, attr), new_items])
|
||||
else:
|
||||
items = new_items
|
||||
return items, new_items
|
||||
|
||||
|
||||
def removeind(ins, attr, indices):
|
||||
"""Remove items at specified indices."""
|
||||
mask = torch.ones(len(ins), dtype=torch.bool)
|
||||
mask[indices] = False
|
||||
items = getattr(ins, attr)[mask]
|
||||
return items, mask
|
||||
|
||||
|
||||
def get_cikwargs(init, distribution):
|
||||
"""Return appropriate key-word arguments for a component initializer."""
|
||||
if isinstance(init, AbstractClassAwareCompInitializer):
|
||||
cikwargs = dict(distribution=distribution)
|
||||
else:
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
cikwargs = dict(num_components=num_components)
|
||||
return cikwargs
|
||||
|
||||
|
||||
class AbstractComponents(torch.nn.Module):
|
||||
"""Abstract class for all components modules."""
|
||||
@property
|
||||
def num_components(self):
|
||||
"""Current number of components."""
|
||||
return len(self._components)
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
"""Detached Tensor containing the components."""
|
||||
return self._components.detach().cpu()
|
||||
|
||||
def _register_components(self, components):
|
||||
self.register_parameter("_components", Parameter(components))
|
||||
|
||||
def extra_repr(self):
|
||||
return f"components: (shape: {tuple(self._components.shape)})"
|
||||
|
||||
def __len__(self):
|
||||
return self.num_components
|
||||
|
||||
|
||||
class Components(AbstractComponents):
|
||||
"""A set of adaptable Tensors."""
|
||||
def __init__(self, num_components: int,
|
||||
initializer: AbstractComponentsInitializer):
|
||||
super().__init__()
|
||||
self.add_components(num_components, initializer)
|
||||
|
||||
def add_components(self, num_components: int,
|
||||
initializer: AbstractComponentsInitializer):
|
||||
"""Generate and add new components."""
|
||||
assert validate_initializer(initializer, AbstractComponentsInitializer)
|
||||
_components, new_components = gencat(self, "_components", initializer,
|
||||
num_components)
|
||||
self._register_components(_components)
|
||||
return new_components
|
||||
|
||||
def remove_components(self, indices):
|
||||
"""Remove components at specified indices."""
|
||||
_components, mask = removeind(self, "_components", indices)
|
||||
self._register_components(_components)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the components parameter Tensor."""
|
||||
return self._components
|
||||
|
||||
|
||||
class AbstractLabels(torch.nn.Module):
|
||||
"""Abstract class for all labels modules."""
|
||||
@property
|
||||
def labels(self):
|
||||
return self._labels.cpu()
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
return len(self._labels)
|
||||
|
||||
@property
|
||||
def unique_labels(self):
|
||||
return torch.unique(self._labels)
|
||||
|
||||
@property
|
||||
def num_unique(self):
|
||||
return len(self.unique_labels)
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
unique, counts = torch.unique(self._labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
return dict(zip(unique.tolist(), counts.tolist()))
|
||||
|
||||
def _register_labels(self, labels):
|
||||
self.register_buffer("_labels", labels)
|
||||
|
||||
def extra_repr(self):
|
||||
r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
|
||||
if len(self.distribution) < 11: # avoid lengthy representations
|
||||
d = self.distribution
|
||||
unique, counts = list(d.keys()), list(d.values())
|
||||
r += f", unique: {unique}, counts: {counts}"
|
||||
return r
|
||||
|
||||
def __len__(self):
|
||||
return self.num_labels
|
||||
|
||||
|
||||
class Labels(AbstractLabels):
|
||||
"""A set of standalone labels."""
|
||||
def __init__(self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||
super().__init__()
|
||||
self.add_labels(distribution, initializer)
|
||||
|
||||
def add_labels(
|
||||
self,
|
||||
distribution: Union[dict, tuple, list],
|
||||
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||
"""Generate and add new labels."""
|
||||
assert validate_initializer(initializer, AbstractLabelsInitializer)
|
||||
_labels, new_labels = gencat(self, "_labels", initializer,
|
||||
distribution)
|
||||
self._register_labels(_labels)
|
||||
return new_labels
|
||||
|
||||
def remove_labels(self, indices):
|
||||
"""Remove labels at specified indices."""
|
||||
_labels, mask = removeind(self, "_labels", indices)
|
||||
self._register_labels(_labels)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the labels."""
|
||||
return self._labels
|
||||
|
||||
|
||||
class LabeledComponents(AbstractComponents):
|
||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||
def __init__(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
components_initializer: AbstractComponentsInitializer,
|
||||
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||
super().__init__()
|
||||
self.add_components(distribution, components_initializer,
|
||||
labels_initializer)
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
unique, counts = torch.unique(self._labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
return dict(zip(unique.tolist(), counts.tolist()))
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.distribution.keys())
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""Tensor containing the component labels."""
|
||||
return self._labels.cpu()
|
||||
|
||||
def _register_labels(self, labels):
|
||||
self.register_buffer("_labels", labels)
|
||||
|
||||
def add_components(
|
||||
self,
|
||||
distribution,
|
||||
components_initializer,
|
||||
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||
"""Generate and add new components and labels."""
|
||||
assert validate_initializer(components_initializer,
|
||||
AbstractComponentsInitializer)
|
||||
assert validate_initializer(labels_initializer,
|
||||
AbstractLabelsInitializer)
|
||||
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||
_components, new_components = gencat(self, "_components",
|
||||
components_initializer,
|
||||
**cikwargs)
|
||||
_labels, new_labels = gencat(self, "_labels", labels_initializer,
|
||||
distribution)
|
||||
self._register_components(_components)
|
||||
self._register_labels(_labels)
|
||||
return new_components, new_labels
|
||||
|
||||
def remove_components(self, indices):
|
||||
"""Remove components and labels at specified indices."""
|
||||
_components, mask = removeind(self, "_components", indices)
|
||||
_labels, mask = removeind(self, "_labels", indices)
|
||||
self._register_components(_components)
|
||||
self._register_labels(_labels)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the components parameter Tensor and labels."""
|
||||
return self._components, self._labels
|
||||
|
||||
|
||||
class Reasonings(torch.nn.Module):
|
||||
"""A set of standalone reasoning matrices.
|
||||
|
||||
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer:
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""Tensor containing the reasoning matrices."""
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_buffer("_reasonings", reasonings)
|
||||
|
||||
def add_reasonings(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer:
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
||||
"""Generate and add new reasonings."""
|
||||
assert validate_initializer(initializer, AbstractReasoningsInitializer)
|
||||
_reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
|
||||
distribution)
|
||||
self._register_reasonings(_reasonings)
|
||||
return new_reasonings
|
||||
|
||||
def remove_reasonings(self, indices):
|
||||
"""Remove reasonings at specified indices."""
|
||||
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||
self._register_reasonings(_reasonings)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the reasonings."""
|
||||
return self._reasonings
|
||||
|
||||
|
||||
class ReasoningComponents(AbstractComponents):
|
||||
r"""A set of components and a corresponding adapatable reasoning matrices.
|
||||
|
||||
Every component has its own reasoning matrix.
|
||||
|
||||
A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
|
||||
first element is called positive reasoning :math:`p`, the second negative
|
||||
reasoning :math:`n`. A components can reason in favour (positive) of a
|
||||
class, against (negative) a class or not at all (neutral).
|
||||
|
||||
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
|
||||
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
|
||||
three element probability distribution.
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
components_initializer: AbstractComponentsInitializer,
|
||||
reasonings_initializer:
|
||||
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()):
|
||||
super().__init__()
|
||||
self.add_components(distribution, components_initializer,
|
||||
reasonings_initializer)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""Tensor containing the reasoning matrices."""
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
@property
|
||||
def reasoning_matrices(self):
|
||||
"""Reasoning matrices for each class."""
|
||||
with torch.no_grad():
|
||||
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - pk) * B
|
||||
ik = 1 - pk - nk
|
||||
matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
|
||||
return matrices.cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||
|
||||
def add_components(self, distribution, components_initializer,
|
||||
reasonings_initializer: AbstractReasoningsInitializer):
|
||||
"""Generate and add new components and reasonings."""
|
||||
assert validate_initializer(components_initializer,
|
||||
AbstractComponentsInitializer)
|
||||
assert validate_initializer(reasonings_initializer,
|
||||
AbstractReasoningsInitializer)
|
||||
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||
_components, new_components = gencat(self, "_components",
|
||||
components_initializer,
|
||||
**cikwargs)
|
||||
_reasonings, new_reasonings = gencat(self, "_reasonings",
|
||||
reasonings_initializer,
|
||||
distribution)
|
||||
self._register_components(_components)
|
||||
self._register_reasonings(_reasonings)
|
||||
return new_components, new_reasonings
|
||||
|
||||
def remove_components(self, indices):
|
||||
"""Remove components and reasonings at specified indices."""
|
||||
_components, mask = removeind(self, "_components", indices)
|
||||
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||
self._register_components(_components)
|
||||
self._register_reasonings(_reasonings)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the components and reasonings."""
|
||||
return self._components, self._reasonings
|
98
prototorch/core/distances.py
Normal file
98
prototorch/core/distances.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""ProtoTorch distances"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
|
||||
|
||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
distances = torch.sum(differences_raised, axis=2)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance(x, y):
|
||||
r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
|
||||
|
||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_v2(x, y):
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
diff = y - x.unsqueeze(1)
|
||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||
# batch diagonal. See:
|
||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
||||
# print(f"{distances.shape=}") # (nx, ny)
|
||||
return distances
|
||||
|
||||
|
||||
def lpnorm_distance(x, y, p):
|
||||
r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
||||
Also known as Minkowski distance.
|
||||
|
||||
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
||||
|
||||
Calls ``torch.cdist``
|
||||
|
||||
:param p: p parameter of the lp norm
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
distances = torch.cdist(x, y, p=p)
|
||||
return distances
|
||||
|
||||
|
||||
def omega_distance(x, y, omega):
|
||||
r"""Omega distance.
|
||||
|
||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
return distances
|
||||
|
||||
|
||||
def lomega_distance(x, y, omegas):
|
||||
r"""Localized Omega distance.
|
||||
|
||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
batchwise_difference = expanded_y - projected_x
|
||||
differences_squared = batchwise_difference**2
|
||||
distances = torch.sum(differences_squared, dim=2)
|
||||
distances = distances.permute(1, 0)
|
||||
return distances
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
494
prototorch/core/initializers.py
Normal file
494
prototorch/core/initializers.py
Normal file
@ -0,0 +1,494 @@
|
||||
"""ProtoTorch code initializers"""
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import (
|
||||
Callable,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import parse_data_arg, parse_distribution
|
||||
|
||||
|
||||
# Components
|
||||
class AbstractComponentsInitializer(ABC):
|
||||
"""Abstract class for all components initializers."""
|
||||
...
|
||||
|
||||
|
||||
class LiteralCompInitializer(AbstractComponentsInitializer):
|
||||
"""'Generate' the provided components.
|
||||
|
||||
Use this to 'generate' pre-initialized components elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, components):
|
||||
self.components = components
|
||||
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return `self.components`."""
|
||||
if not isinstance(self.components, torch.Tensor):
|
||||
wmsg = f"Converting components to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
self.components = torch.Tensor(self.components)
|
||||
return self.components
|
||||
|
||||
|
||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all dimension-aware components initializers."""
|
||||
def __init__(self, shape: Union[Iterable, int]):
|
||||
if isinstance(shape, Iterable):
|
||||
self.component_shape = tuple(shape)
|
||||
else:
|
||||
self.component_shape = (shape, )
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, num_components: int):
|
||||
...
|
||||
|
||||
|
||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||
"""Generate zeros corresponding to the components shape."""
|
||||
def generate(self, num_components: int):
|
||||
components = torch.zeros((num_components, ) + self.component_shape)
|
||||
return components
|
||||
|
||||
|
||||
class OnesCompInitializer(ShapeAwareCompInitializer):
|
||||
"""Generate ones corresponding to the components shape."""
|
||||
def generate(self, num_components: int):
|
||||
components = torch.ones((num_components, ) + self.component_shape)
|
||||
return components
|
||||
|
||||
|
||||
class FillValueCompInitializer(OnesCompInitializer):
|
||||
"""Generate components with the provided `fill_value`."""
|
||||
def __init__(self, shape, fill_value: float = 1.0):
|
||||
super().__init__(shape)
|
||||
self.fill_value = fill_value
|
||||
|
||||
def generate(self, num_components: int):
|
||||
ones = super().generate(num_components)
|
||||
components = ones.fill_(self.fill_value)
|
||||
return components
|
||||
|
||||
|
||||
class UniformCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a continuous uniform distribution."""
|
||||
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, num_components: int):
|
||||
ones = super().generate(num_components)
|
||||
components = self.scale * ones.uniform_(self.minimum, self.maximum)
|
||||
return components
|
||||
|
||||
|
||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a standard normal distribution."""
|
||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.shift = shift
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, num_components: int):
|
||||
ones = super().generate(num_components)
|
||||
components = self.scale * (torch.randn_like(ones) + self.shift)
|
||||
return components
|
||||
|
||||
|
||||
class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all data-aware components initializers.
|
||||
|
||||
Components generated by data-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
`data` has to be a torch tensor.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
noise: float = 0.0,
|
||||
transform: Callable = torch.nn.Identity()):
|
||||
self.data = data
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
|
||||
def generate_end_hook(self, samples):
|
||||
drift = torch.rand_like(samples) * self.noise
|
||||
components = self.transform(samples + drift)
|
||||
return components
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, num_components: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""'Generate' the components from the provided data."""
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
return components
|
||||
|
||||
|
||||
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by uniformly sampling from the provided data."""
|
||||
def generate(self, num_components: int):
|
||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||
samples = self.data[indices]
|
||||
components = self.generate_end_hook(samples)
|
||||
return components
|
||||
|
||||
|
||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by computing the mean of the provided data."""
|
||||
def generate(self, num_components: int):
|
||||
mean = self.data.mean(dim=0)
|
||||
repeat_dim = [num_components] + [1] * len(mean.shape)
|
||||
samples = mean.repeat(repeat_dim)
|
||||
components = self.generate_end_hook(samples)
|
||||
return components
|
||||
|
||||
|
||||
class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all class-aware components initializers.
|
||||
|
||||
Components generated by class-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
`data` could be a torch Dataset or DataLoader or a list/tuple of data and
|
||||
target tensors.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data,
|
||||
noise: float = 0.0,
|
||||
transform: Callable = torch.nn.Identity()):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
self.clabels = torch.unique(self.targets).int().tolist()
|
||||
self.num_classes = len(self.clabels)
|
||||
|
||||
def generate_end_hook(self, samples):
|
||||
drift = torch.rand_like(samples) * self.noise
|
||||
components = self.transform(samples + drift)
|
||||
return components
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
del self.targets
|
||||
|
||||
|
||||
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""'Generate' components from provided data and requested distribution."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
return components
|
||||
|
||||
|
||||
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""Abstract class for all stratified components initializers."""
|
||||
@property
|
||||
@abstractmethod
|
||||
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
||||
...
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
components = torch.tensor([])
|
||||
for k, v in distribution.items():
|
||||
stratified_data = self.data[self.targets == k]
|
||||
initializer = self.subinit_type(
|
||||
stratified_data,
|
||||
noise=self.noise,
|
||||
transform=self.transform,
|
||||
)
|
||||
samples = initializer.generate(num_components=v)
|
||||
components = torch.cat([components, samples])
|
||||
return components
|
||||
|
||||
|
||||
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components using stratified sampling from the provided data."""
|
||||
@property
|
||||
def subinit_type(self):
|
||||
return SelectionCompInitializer
|
||||
|
||||
|
||||
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components at stratified means of the provided data."""
|
||||
@property
|
||||
def subinit_type(self):
|
||||
return MeanCompInitializer
|
||||
|
||||
|
||||
# Labels
|
||||
class AbstractLabelsInitializer(ABC):
|
||||
"""Abstract class for all labels initializers."""
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
|
||||
|
||||
class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
||||
"""'Generate' the provided labels.
|
||||
|
||||
Use this to 'generate' pre-initialized labels elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, labels):
|
||||
self.labels = labels
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
"""Ignore `distribution` and simply return `self.labels`.
|
||||
|
||||
Convert to long tensor, if necessary.
|
||||
"""
|
||||
labels = self.labels
|
||||
if not isinstance(labels, torch.LongTensor):
|
||||
wmsg = f"Converting labels to {torch.LongTensor}..."
|
||||
warnings.warn(wmsg)
|
||||
labels = torch.LongTensor(labels)
|
||||
return labels
|
||||
|
||||
|
||||
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||
"""'Generate' the labels from a torch Dataset."""
|
||||
def __init__(self, data):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
"""Ignore `num_components` and simply return `self.targets`."""
|
||||
return self.targets
|
||||
|
||||
|
||||
class LabelsInitializer(AbstractLabelsInitializer):
|
||||
"""Generate labels from `distribution`."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
labels_list = []
|
||||
for k, v in distribution.items():
|
||||
labels_list.extend([k] * v)
|
||||
labels = torch.LongTensor(labels_list)
|
||||
return labels
|
||||
|
||||
|
||||
class OneHotLabelsInitializer(LabelsInitializer):
|
||||
"""Generate one-hot-encoded labels from `distribution`."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_classes = len(distribution.keys())
|
||||
# this breaks if class labels are not [0,...,nclasses]
|
||||
labels = torch.eye(num_classes)[super().generate(distribution)]
|
||||
return labels
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def compute_shape(self, distribution):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
num_classes = len(distribution.keys())
|
||||
return (num_components, num_classes, 2)
|
||||
|
||||
def generate_end_hook(self, reasonings):
|
||||
if not self.components_first:
|
||||
reasonings = reasonings.permute(2, 1, 0)
|
||||
return reasonings
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""'Generate' the provided reasonings.
|
||||
|
||||
Use this to 'generate' pre-initialized reasonings elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, reasonings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.reasonings = reasonings
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
"""Ignore `distributuion` and simply return self.reasonings."""
|
||||
reasonings = self.reasonings
|
||||
if not isinstance(reasonings, torch.Tensor):
|
||||
wmsg = f"Converting reasonings to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
reasonings = torch.Tensor(reasonings)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.zeros(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with ones."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.ones(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are randomly initialized."""
|
||||
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Each component reasons positively for exactly one class."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||
A = OneHotLabelsInitializer().generate(distribution)
|
||||
B = torch.zeros(num_components, num_classes)
|
||||
reasonings = torch.stack([A, B], dim=-1)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
# Transforms
|
||||
class AbstractTransformInitializer(ABC):
|
||||
"""Abstract class for all transform initializers."""
|
||||
...
|
||||
|
||||
|
||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||
"""Abstract class for all linear transform initializers."""
|
||||
def __init__(self, out_dim_first: bool = False):
|
||||
self.out_dim_first = out_dim_first
|
||||
|
||||
def generate_end_hook(self, weights):
|
||||
if self.out_dim_first:
|
||||
weights = weights.permute(1, 0)
|
||||
return weights
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with zeros."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with ones."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.ones(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with the largest possible identity matrix."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
I = torch.eye(min(in_dim, out_dim))
|
||||
weights[:I.shape[0], :I.shape[1]] = I
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||
"""Abstract class for all data-aware linear transform initializers."""
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
noise: float = 0.0,
|
||||
transform: Callable = torch.nn.Identity()):
|
||||
self.data = data
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
|
||||
def generate_end_hook(self, weights: torch.Tensor):
|
||||
drift = torch.rand_like(weights) * self.noise
|
||||
weights = self.transform(weights + drift)
|
||||
if self.out_dim_first:
|
||||
weights = weights.permute(1, 0)
|
||||
return weights
|
||||
|
||||
|
||||
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||
"""Initialize a matrix with Eigenvectors from the data."""
|
||||
@abstractmethod
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
# Aliases - Components
|
||||
CACI = ClassAwareCompInitializer
|
||||
DACI = DataAwareCompInitializer
|
||||
FVCI = FillValueCompInitializer
|
||||
LCI = LiteralCompInitializer
|
||||
MCI = MeanCompInitializer
|
||||
OCI = OnesCompInitializer
|
||||
RNCI = RandomNormalCompInitializer
|
||||
SCI = SelectionCompInitializer
|
||||
SMCI = StratifiedMeanCompInitializer
|
||||
SSCI = StratifiedSelectionCompInitializer
|
||||
UCI = UniformCompInitializer
|
||||
ZCI = ZerosCompInitializer
|
||||
|
||||
# Aliases - Labels
|
||||
DLI = DataAwareLabelsInitializer
|
||||
LI = LabelsInitializer
|
||||
LLI = LiteralLabelsInitializer
|
||||
OHLI = OneHotLabelsInitializer
|
||||
|
||||
# Aliases - Reasonings
|
||||
LRI = LiteralReasoningsInitializer
|
||||
ORI = OnesReasoningsInitializer
|
||||
PPRI = PurePositiveReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
|
||||
# Aliases - Transforms
|
||||
Eye = EyeTransformInitializer
|
||||
OLTI = OnesLinearTransformInitializer
|
||||
ZLTI = ZerosLinearTransformInitializer
|
||||
PCALTI = PCALinearTransformInitializer
|
@ -1,8 +1,11 @@
|
||||
"""ProtoTorch loss functions."""
|
||||
"""ProtoTorch losses"""
|
||||
|
||||
import torch
|
||||
|
||||
from ..nn.activations import get_activation
|
||||
|
||||
|
||||
# Helpers
|
||||
def _get_matcher(targets, labels):
|
||||
"""Returns a boolean tensor."""
|
||||
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
|
||||
@ -28,6 +31,7 @@ def _get_dp_dm(distances, targets, plabels, with_indices=False):
|
||||
return dp.values, dm.values
|
||||
|
||||
|
||||
# GLVQ
|
||||
def glvq_loss(distances, target_labels, prototype_labels):
|
||||
"""GLVQ loss function with support for one-hot labels."""
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
@ -92,3 +96,76 @@ def rslvq_loss(probabilities, targets, prototype_labels):
|
||||
likelihood = correct / whole
|
||||
log_likelihood = torch.log(likelihood)
|
||||
return -1.0 * log_likelihood
|
||||
|
||||
|
||||
def margin_loss(y_pred, y_true, margin=0.3):
|
||||
"""Compute the margin loss."""
|
||||
dp = torch.sum(y_true * y_pred, dim=-1)
|
||||
dm = torch.max(y_pred - y_true, dim=-1).values
|
||||
return torch.nn.functional.relu(dm - dp + margin)
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
|
||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
||||
|
||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||
def __init__(self,
|
||||
margin=0.3,
|
||||
size_average=None,
|
||||
reduce=None,
|
||||
reduction="mean"):
|
||||
super().__init__(size_average, reduce, reduction)
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
return margin_loss(y_pred, y_true, self.margin)
|
||||
|
||||
|
||||
class NeuralGasEnergy(torch.nn.Module):
|
||||
def __init__(self, lm, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.lm = lm
|
||||
|
||||
def forward(self, d):
|
||||
order = torch.argsort(d, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
|
||||
|
||||
return cost, order
|
||||
|
||||
def extra_repr(self):
|
||||
return f"lambda: {self.lm}"
|
||||
|
||||
@staticmethod
|
||||
def _nghood_fn(rankings, lm):
|
||||
return torch.exp(-rankings / lm)
|
||||
|
||||
|
||||
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||
def __init__(self, topology_layer, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topology_layer = topology_layer
|
||||
|
||||
@staticmethod
|
||||
def _nghood_fn(rankings, topology):
|
||||
winner = rankings[:, 0]
|
||||
|
||||
weights = torch.zeros_like(rankings, dtype=torch.float)
|
||||
weights[torch.arange(rankings.shape[0]), winner] = 1.0
|
||||
|
||||
neighbours = topology.get_neighbours(winner)
|
||||
|
||||
weights[neighbours] = 0.1
|
||||
|
||||
return weights
|
@ -1,4 +1,4 @@
|
||||
"""ProtoTorch pooling functions."""
|
||||
"""ProtoTorch pooling"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
@ -78,3 +78,27 @@ def stratified_prod_pooling(values: torch.Tensor,
|
||||
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
|
||||
fill_value=1.0)
|
||||
return winning_values
|
||||
|
||||
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_prod_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMinPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_min_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMaxPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_max_pooling(values, labels)
|
@ -1,7 +1,19 @@
|
||||
"""ProtoTorch similarity functions."""
|
||||
"""ProtoTorch similarities."""
|
||||
|
||||
import torch
|
||||
|
||||
from .distances import euclidean_distance
|
||||
|
||||
|
||||
def gaussian(x, variance=1.0):
|
||||
return torch.exp(-(x * x) / (2 * variance))
|
||||
|
||||
|
||||
def euclidean_similarity(x, y, variance=1.0):
|
||||
distances = euclidean_distance(x, y)
|
||||
similarities = gaussian(distances, variance)
|
||||
return similarities
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
"""Compute the cosine similarity between :math:`x` and :math:`y`.
|
||||
@ -9,6 +21,7 @@ def cosine_similarity(x, y):
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
norm_x = x.pow(2).sum(1).sqrt()
|
||||
norm_y = y.pow(2).sum(1).sqrt()
|
||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
43
prototorch/core/transforms.py
Normal file
43
prototorch/core/transforms.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""ProtoTorch transforms"""
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeTransformInitializer,
|
||||
)
|
||||
|
||||
|
||||
class LinearTransform(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||
super().__init__()
|
||||
self.set_weights(in_dim, out_dim, initializer)
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
return self._weights.detach().cpu()
|
||||
|
||||
def _register_weights(self, weights):
|
||||
self.register_parameter("_weights", Parameter(weights))
|
||||
|
||||
def set_weights(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||
weights = initializer.generate(in_dim, out_dim)
|
||||
self._register_weights(weights)
|
||||
|
||||
def forward(self, x):
|
||||
return x @ self.weights.T
|
||||
|
||||
|
||||
# Aliases
|
||||
Omega = LinearTransform
|
@ -1,6 +1,12 @@
|
||||
"""ProtoTorch datasets."""
|
||||
"""ProtoTorch datasets"""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .sklearn import Blobs, Circles, Iris, Moons, Random
|
||||
from .sklearn import (
|
||||
Blobs,
|
||||
Circles,
|
||||
Iris,
|
||||
Moons,
|
||||
Random,
|
||||
)
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
||||
|
@ -1,10 +1,11 @@
|
||||
"""ProtoTorch abstract dataset classes.
|
||||
"""ProtoTorch abstract dataset classes
|
||||
|
||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`.
|
||||
|
||||
For the original code, see:
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -12,15 +13,6 @@ import os
|
||||
import torch
|
||||
|
||||
|
||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||
def __init__(self, data, targets):
|
||||
self.data = torch.Tensor(data)
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
"""Abstract dataset class to be inherited."""
|
||||
|
||||
@ -44,7 +36,7 @@ class ProtoDataset(Dataset):
|
||||
training_file = "training.pt"
|
||||
test_file = "test.pt"
|
||||
|
||||
def __init__(self, root, train=True, download=True, verbose=True):
|
||||
def __init__(self, root="", train=True, download=True, verbose=True):
|
||||
super().__init__(root)
|
||||
self.train = train # training set or test set
|
||||
self.verbose = verbose
|
||||
@ -96,3 +88,12 @@ class ProtoDataset(Dataset):
|
||||
|
||||
def _download(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||
def __init__(self, data, targets):
|
||||
self.data = torch.Tensor(data)
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
@ -1,5 +0,0 @@
|
||||
"""ProtoTorch functions."""
|
||||
|
||||
from .activations import identity, sigmoid_beta, swish_beta
|
||||
from .competitions import knnc, wtac
|
||||
from .pooling import *
|
@ -1,28 +0,0 @@
|
||||
"""ProtoTorch competition functions."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def wtac(distances: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.LongTensor):
|
||||
"""Winner-Takes-All-Competition.
|
||||
|
||||
Returns the labels corresponding to the winners.
|
||||
|
||||
"""
|
||||
winning_indices = torch.min(distances, dim=1).indices
|
||||
winning_labels = labels[winning_indices].squeeze()
|
||||
return winning_labels
|
||||
|
||||
|
||||
def knnc(distances: torch.Tensor,
|
||||
labels: torch.LongTensor,
|
||||
k: int = 1) -> (torch.LongTensor):
|
||||
"""K-Nearest-Neighbors-Competition.
|
||||
|
||||
Returns the labels corresponding to the winners.
|
||||
|
||||
"""
|
||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
||||
return winning_labels
|
@ -1,259 +0,0 @@
|
||||
"""ProtoTorch distance functions."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||
equal_int_shape, get_flat)
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
|
||||
|
||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
distances = torch.sum(differences_raised, axis=2)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance(x, y):
|
||||
r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
|
||||
|
||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_v2(x, y):
|
||||
x, y = get_flat(x, y)
|
||||
diff = y - x.unsqueeze(1)
|
||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||
# batch diagonal. See:
|
||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
||||
# print(f"{distances.shape=}") # (nx, ny)
|
||||
return distances
|
||||
|
||||
|
||||
def lpnorm_distance(x, y, p):
|
||||
r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
||||
Also known as Minkowski distance.
|
||||
|
||||
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
||||
|
||||
Calls ``torch.cdist``
|
||||
|
||||
:param p: p parameter of the lp norm
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
distances = torch.cdist(x, y, p=p)
|
||||
return distances
|
||||
|
||||
|
||||
def omega_distance(x, y, omega):
|
||||
r"""Omega distance.
|
||||
|
||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
return distances
|
||||
|
||||
|
||||
def lomega_distance(x, y, omegas):
|
||||
r"""Localized Omega distance.
|
||||
|
||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
batchwise_difference = expanded_y - projected_x
|
||||
differences_squared = batchwise_difference**2
|
||||
distances = torch.sum(differences_squared, dim=2)
|
||||
distances = distances.permute(1, 0)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
||||
r"""Computes an euclidean distances matrix given two distinct vectors.
|
||||
last dimension must be the vector dimension!
|
||||
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
|
||||
|
||||
- ``x.shape = (number_of_x_vectors, vector_dim)``
|
||||
- ``y.shape = (number_of_y_vectors, vector_dim)``
|
||||
|
||||
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
|
||||
"""
|
||||
for tensor in [x, y]:
|
||||
if tensor.ndim != 2:
|
||||
raise ValueError(
|
||||
"The tensor dimension must be two. You provide: tensor.ndim=" +
|
||||
str(tensor.ndim) + ".")
|
||||
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
||||
raise ValueError(
|
||||
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
|
||||
+ str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
|
||||
str(tuple(y.shape)[1]) + ".")
|
||||
|
||||
y = torch.transpose(y)
|
||||
|
||||
diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
|
||||
torch.sum(y**2, axis=0, keepdims=True))
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
return diss
|
||||
|
||||
|
||||
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
||||
r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
|
||||
|
||||
For more info about Tangen distances see
|
||||
|
||||
DOI:10.1109/IJCNN.2016.7727534.
|
||||
|
||||
The subspaces is always assumed as transposed and must be orthogonal!
|
||||
For local non sparse signals subspaces must be provided!
|
||||
|
||||
- shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
- shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
- shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
|
||||
subspace should be orthogonalized
|
||||
Pytorch implementation of Sascha Saralajew's tensorflow code.
|
||||
Translation by Christoph Raab
|
||||
"""
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
subspace_int_shape = tuple(subspaces.shape)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
atom_axes = list(range(3, len(signal_int_shape)))
|
||||
# for sparse signals, we use the memory efficient implementation
|
||||
if signal_int_shape[1] == 1:
|
||||
signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
|
||||
|
||||
if len(atom_axes) > 1:
|
||||
protos = torch.reshape(protos, [proto_shape[0], -1])
|
||||
|
||||
if subspaces.ndim == 2:
|
||||
# clean solution without map if the matrix_scope is global
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
||||
subspaces, torch.transpose(subspaces))
|
||||
|
||||
projected_signals = torch.dot(signals, projectors)
|
||||
projected_protos = torch.dot(protos, projectors)
|
||||
|
||||
diss = euclidean_distance_matrix(projected_signals,
|
||||
projected_protos,
|
||||
squared=squared,
|
||||
epsilon=epsilon)
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
|
||||
# no solution without map possible --> memory efficient but slow!
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
||||
subspaces,
|
||||
subspaces) # K.batch_dot(subspaces, subspaces, [2, 2])
|
||||
|
||||
projected_protos = (protos @ subspaces
|
||||
).T # K.batch_dot(projectors, protos, [1, 1]))
|
||||
|
||||
def projected_norm(projector):
|
||||
return torch.sum(torch.dot(signals, projector)**2, axis=1)
|
||||
|
||||
diss = (torch.transpose(map(projected_norm, projectors)) -
|
||||
2 * torch.dot(signals, projected_protos) +
|
||||
torch.sum(projected_protos**2, axis=0, keepdims=True))
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
signals = signals.permute([0, 2, 1] + atom_axes)
|
||||
|
||||
diff = signals - protos
|
||||
|
||||
# global tangent space
|
||||
if subspaces.ndim == 2:
|
||||
# Scope Projectors
|
||||
projectors = subspaces #
|
||||
|
||||
# Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
projected_diff = diff @ projectors
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[0], signal_shape[2], signal_shape[1]) +
|
||||
signal_shape[3:],
|
||||
)
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([0, 2, 1])
|
||||
|
||||
# local tangent spaces
|
||||
else:
|
||||
# Scope: Calculate Projectors
|
||||
projectors = subspaces
|
||||
|
||||
# Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
diff = diff.permute([1, 0, 2])
|
||||
projected_diff = torch.bmm(diff, projectors)
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:],
|
||||
)
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1)
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
@ -1,94 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_flat(*args):
|
||||
rv = [x.view(x.size(0), -1) for x in args]
|
||||
return rv
|
||||
|
||||
|
||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||
"""Computes the accuracy of a prototype based model.
|
||||
via Winner-Takes-All rule.
|
||||
Requirement:
|
||||
y_pred.shape == y_true.shape
|
||||
unique(y_pred) in plabels
|
||||
"""
|
||||
with torch.no_grad():
|
||||
idx = torch.argmin(y_pred, axis=1)
|
||||
return torch.true_divide(torch.sum(y_true == plabels[idx]),
|
||||
len(y_pred)) * 100
|
||||
|
||||
|
||||
def predict_label(y_pred, plabels):
|
||||
r""" Predicts labels given a prediction of a prototype based model.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return plabels[torch.argmin(y_pred, 1)]
|
||||
|
||||
|
||||
def mixed_shape(inputs):
|
||||
if not torch.is_tensor(inputs):
|
||||
raise ValueError("Input must be a tensor.")
|
||||
else:
|
||||
int_shape = list(inputs.shape)
|
||||
# sometimes int_shape returns mixed integer types
|
||||
int_shape = [int(i) if i is not None else i for i in int_shape]
|
||||
tensor_shape = inputs.shape
|
||||
|
||||
for i, s in enumerate(int_shape):
|
||||
if s is None:
|
||||
int_shape[i] = tensor_shape[i]
|
||||
return tuple(int_shape)
|
||||
|
||||
|
||||
def equal_int_shape(shape_1, shape_2):
|
||||
if not isinstance(shape_1,
|
||||
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
||||
raise ValueError("Input shapes must list or tuple.")
|
||||
for shape in [shape_1, shape_2]:
|
||||
if not all([isinstance(x, int) or x is None for x in shape]):
|
||||
raise ValueError(
|
||||
"Input shapes must be list or tuple of int and None values.")
|
||||
|
||||
if len(shape_1) != len(shape_2):
|
||||
return False
|
||||
else:
|
||||
for axis, value in enumerate(shape_1):
|
||||
if value is not None and shape_2[axis] not in {value, None}:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_shapes(signal_int_shape, proto_int_shape):
|
||||
if len(signal_int_shape) < 4:
|
||||
raise ValueError(
|
||||
"The number of signal dimensions must be >=4. You provide: " +
|
||||
str(len(signal_int_shape)))
|
||||
|
||||
if len(proto_int_shape) < 2:
|
||||
raise ValueError(
|
||||
"The number of proto dimensions must be >=2. You provide: " +
|
||||
str(len(proto_int_shape)))
|
||||
|
||||
if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
|
||||
raise ValueError(
|
||||
"The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
|
||||
+ str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
|
||||
str(proto_int_shape[1:]))
|
||||
|
||||
# not a sparse signal
|
||||
if signal_int_shape[1] != 1:
|
||||
if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
|
||||
raise ValueError(
|
||||
"If the signal is not sparse, the number of prototypes must be equal in signals and "
|
||||
"protos. You provide: " + str(signal_int_shape[1]) + " != " +
|
||||
str(proto_int_shape[0]))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _int_and_mixed_shape(tensor):
|
||||
shape = mixed_shape(tensor)
|
||||
int_shape = tuple(i if isinstance(i, int) else None for i in shape)
|
||||
|
||||
return shape, int_shape
|
@ -1,107 +0,0 @@
|
||||
"""ProtoTorch initialization functions."""
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
|
||||
INITIALIZERS = dict()
|
||||
|
||||
|
||||
def register_initializer(function):
|
||||
"""Add the initializer to the registry."""
|
||||
INITIALIZERS[function.__name__] = function
|
||||
return function
|
||||
|
||||
|
||||
def labels_from(distribution, one_hot=True):
|
||||
"""Takes a distribution tensor and returns a labels tensor."""
|
||||
num_classes = distribution.shape[0]
|
||||
llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
|
||||
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
||||
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
|
||||
plabels = torch.tensor(flat_llist, requires_grad=False)
|
||||
if one_hot:
|
||||
return torch.eye(num_classes)[plabels]
|
||||
return plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def ones(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.ones(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.zeros(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def rand(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.rand(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def randn(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.randn(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(num_protos, pdim)
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
for i, label in enumerate(plabels):
|
||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||
if one_hot:
|
||||
num_classes = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
xl = x_train[matcher]
|
||||
mean_xl = torch.mean(xl, dim=0)
|
||||
protos[i] = mean_xl
|
||||
plabels = labels_from(prototype_distribution, one_hot=one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_random(x_train,
|
||||
y_train,
|
||||
prototype_distribution,
|
||||
one_hot=True,
|
||||
epsilon=1e-7):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(num_protos, pdim)
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
for i, label in enumerate(plabels):
|
||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||
if one_hot:
|
||||
num_classes = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
xl = x_train[matcher]
|
||||
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||
random_xl = xl[rand_index]
|
||||
protos[i] = random_xl + epsilon
|
||||
plabels = labels_from(prototype_distribution, one_hot=one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
def get_initializer(funcname):
|
||||
"""Deserialize the initializer."""
|
||||
if callable(funcname):
|
||||
return funcname
|
||||
if funcname in INITIALIZERS:
|
||||
return INITIALIZERS.get(funcname)
|
||||
raise NameError(f"Initializer {funcname} was not found.")
|
@ -1,35 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def orthogonalization(tensors):
|
||||
r""" Orthogonalization of a given tensor via polar decomposition.
|
||||
"""
|
||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||
u_shape = tuple(list(u.shape))
|
||||
v_shape = tuple(list(v.shape))
|
||||
|
||||
# reshape to (num x N x M)
|
||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||
|
||||
out = u @ v.permute([0, 2, 1])
|
||||
|
||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def trace_normalization(tensors):
|
||||
r""" Trace normalization
|
||||
"""
|
||||
epsilon = torch.tensor([1e-10], dtype=torch.float64)
|
||||
# Scope trace_normalization
|
||||
constant = torch.trace(tensors)
|
||||
|
||||
if epsilon != 0:
|
||||
constant = torch.max(constant, epsilon)
|
||||
|
||||
return tensors / constant
|
@ -1,32 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
# Functions
|
||||
def gaussian(distances, variance):
|
||||
return torch.exp(-(distances * distances) / (2 * variance))
|
||||
|
||||
|
||||
def rank_scaled_gaussian(distances, lambd):
|
||||
order = torch.argsort(distances, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
|
||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||
|
||||
|
||||
# Modules
|
||||
class GaussianPrior(torch.nn.Module):
|
||||
def __init__(self, variance):
|
||||
super().__init__()
|
||||
self.variance = variance
|
||||
|
||||
def forward(self, distances):
|
||||
return gaussian(distances, self.variance)
|
||||
|
||||
|
||||
class RankScaledGaussianPrior(torch.nn.Module):
|
||||
def __init__(self, lambd):
|
||||
super().__init__()
|
||||
self.lambd = lambd
|
||||
|
||||
def forward(self, distances):
|
||||
return rank_scaled_gaussian(distances, self.lambd)
|
@ -1,5 +0,0 @@
|
||||
"""ProtoTorch modules."""
|
||||
|
||||
from .competitions import *
|
||||
from .pooling import *
|
||||
from .wrappers import LambdaLayer, LossLayer
|
@ -1,42 +0,0 @@
|
||||
"""ProtoTorch Competition Modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.competitions import knnc, wtac
|
||||
|
||||
|
||||
class WTAC(torch.nn.Module):
|
||||
"""Winner-Takes-All-Competition Layer.
|
||||
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, distances, labels):
|
||||
return wtac(distances, labels)
|
||||
|
||||
|
||||
class LTAC(torch.nn.Module):
|
||||
"""Loser-Takes-All-Competition Layer.
|
||||
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, probs, labels):
|
||||
return wtac(-1.0 * probs, labels)
|
||||
|
||||
|
||||
class KNNC(torch.nn.Module):
|
||||
"""K-Nearest-Neighbors-Competition.
|
||||
|
||||
Thin wrapper over the `knnc` function.
|
||||
|
||||
"""
|
||||
def __init__(self, k=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.k = k
|
||||
|
||||
def forward(self, distances, labels):
|
||||
return knnc(distances, labels, k=self.k)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"k: {self.k}"
|
@ -1,59 +0,0 @@
|
||||
"""ProtoTorch losses."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
|
||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
||||
|
||||
class NeuralGasEnergy(torch.nn.Module):
|
||||
def __init__(self, lm, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.lm = lm
|
||||
|
||||
def forward(self, d):
|
||||
order = torch.argsort(d, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
|
||||
|
||||
return cost, order
|
||||
|
||||
def extra_repr(self):
|
||||
return f"lambda: {self.lm}"
|
||||
|
||||
@staticmethod
|
||||
def _nghood_fn(rankings, lm):
|
||||
return torch.exp(-rankings / lm)
|
||||
|
||||
|
||||
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||
def __init__(self, topology_layer, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.topology_layer = topology_layer
|
||||
|
||||
@staticmethod
|
||||
def _nghood_fn(rankings, topology):
|
||||
winner = rankings[:, 0]
|
||||
|
||||
weights = torch.zeros_like(rankings, dtype=torch.float)
|
||||
weights[torch.arange(rankings.shape[0]), winner] = 1.0
|
||||
|
||||
neighbours = topology.get_neighbours(winner)
|
||||
|
||||
weights[neighbours] = 0.1
|
||||
|
||||
return weights
|
@ -1,170 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.functions.distances import euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes: int
|
||||
Number of classes of the given classification problem.
|
||||
|
||||
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
||||
Subspace data for the point approximation, required
|
||||
|
||||
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
||||
prototype data for initalization of the prototypes used in GTLVQ.
|
||||
|
||||
subspace_size: int (default=256,optional)
|
||||
Subspace dimension of the Projectors. Currently only supported
|
||||
with tagnent_projection_type=global.
|
||||
|
||||
tangent_projection_type: string
|
||||
Specifies the tangent projection type
|
||||
options: local
|
||||
local_proj
|
||||
global
|
||||
local: computes the tangent distances without emphasizing projected
|
||||
data. Only distances are available
|
||||
local_proj: computs tangent distances and returns the projected data
|
||||
for further use. Be careful: data is repeated by number of prototypes
|
||||
global: Number of subspaces is set to one and every prototypes
|
||||
uses the same.
|
||||
|
||||
prototypes_per_class: int (default=2,optional)
|
||||
Number of prototypes per class
|
||||
|
||||
feature_dim: int (default=256)
|
||||
Dimensionality of the feature space specified as integer.
|
||||
Prototype dimension.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The GTLVQ [1] is a prototype-based classification learning model. The
|
||||
GTLVQ uses the Tangent-Distances for a local point approximation
|
||||
of an assumed data manifold via prototypial representations.
|
||||
|
||||
The GTLVQ requires subspace projectors for transforming the data
|
||||
and prototypes into the affine subspace. Every prototype is
|
||||
equipped with a specific subpspace and represents a point
|
||||
approximation of the assumed manifold.
|
||||
|
||||
In practice prototypes and data are projected on this manifold
|
||||
and pairwise euclidean distance computes.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
||||
in classification based on manifolc. models and its relation
|
||||
to tangent metric learning. In: 2017 International Joint
|
||||
Conference on Neural Networks (IJCNN).
|
||||
Bd. 2017-May : IEEE, 2017, S. 1756–1765
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data=None,
|
||||
prototype_data=None,
|
||||
subspace_size=256,
|
||||
tangent_projection_type="local",
|
||||
prototypes_per_class=2,
|
||||
feature_dim=256,
|
||||
):
|
||||
super(GTLVQ, self).__init__()
|
||||
|
||||
self.num_protos = num_classes * prototypes_per_class
|
||||
self.num_protos_class = prototypes_per_class
|
||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||
self.feature_dim = feature_dim
|
||||
self.num_classes = num_classes
|
||||
|
||||
cls_initializer = StratifiedMeanInitializer(prototype_data)
|
||||
cls_distribution = {
|
||||
"num_classes": num_classes,
|
||||
"prototypes_per_class": prototypes_per_class,
|
||||
}
|
||||
|
||||
self.cls = LabeledComponents(cls_distribution, cls_initializer)
|
||||
|
||||
if subspace_data is None:
|
||||
raise ValueError("Init Data must be specified!")
|
||||
|
||||
self.tpt = tangent_projection_type
|
||||
with torch.no_grad():
|
||||
if self.tpt == "local":
|
||||
self.init_local_subspace(subspace_data, subspace_size,
|
||||
self.num_protos)
|
||||
elif self.tpt == "global":
|
||||
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||
else:
|
||||
self.subspaces = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.tpt == "local":
|
||||
dis = self.local_tangent_distances(x)
|
||||
elif self.tpt == "gloabl":
|
||||
dis = self.global_tangent_distances(x)
|
||||
else:
|
||||
dis = (x @ self.cls.prototypes.T) / (
|
||||
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
|
||||
self.cls.prototypes, dim=1, keepdim=True).T)
|
||||
return dis
|
||||
|
||||
def init_gobal_subspace(self, data, num_subspaces):
|
||||
_, _, v = torch.svd(data)
|
||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
subspaces = subspace[:, :num_subspaces]
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
|
||||
def init_local_subspace(self, data, num_subspaces, num_protos):
|
||||
data = data - torch.mean(data, dim=0)
|
||||
_, _, v = torch.svd(data, some=False)
|
||||
v = v[:, :num_subspaces]
|
||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
x, projected_prototypes = (
|
||||
x @ self.subspaces,
|
||||
self.cls.prototypes @ self.subspaces,
|
||||
)
|
||||
# Euclidean Distance
|
||||
return euclidean_distance_matrix(x, projected_prototypes)
|
||||
|
||||
def local_tangent_distances(self, x):
|
||||
|
||||
# Tangent Distance
|
||||
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
|
||||
x.size(-1))
|
||||
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
|
||||
self.cls.num_components,
|
||||
x.size(-1))
|
||||
projectors = torch.eye(
|
||||
self.subspaces.shape[-2], device=x.device) - torch.bmm(
|
||||
self.subspaces, self.subspaces.permute([0, 2, 1]))
|
||||
diff = (x - protos)
|
||||
diff = diff.permute([1, 0, 2])
|
||||
diff = torch.bmm(diff, projectors)
|
||||
diff = torch.norm(diff, 2, dim=-1).T
|
||||
return diff
|
||||
|
||||
def get_parameters(self):
|
||||
return {
|
||||
"params": self.cls.components,
|
||||
}, {
|
||||
"params": self.subspaces
|
||||
}
|
||||
|
||||
def orthogonalize_subspace(self):
|
||||
if self.subspaces is not None:
|
||||
with torch.no_grad():
|
||||
ortho_subpsaces = (orthogonalization(self.subspaces)
|
||||
if self.tpt == "global" else
|
||||
torch.nn.init.orthogonal_(self.subspaces))
|
||||
self.subspaces.copy_(ortho_subpsaces)
|
@ -1,32 +0,0 @@
|
||||
"""ProtoTorch Pooling Modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.pooling import (stratified_max_pooling,
|
||||
stratified_min_pooling,
|
||||
stratified_prod_pooling,
|
||||
stratified_sum_pooling)
|
||||
|
||||
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_prod_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMinPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_min_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMaxPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_max_pooling(values, labels)
|
4
prototorch/nn/__init__.py
Normal file
4
prototorch/nn/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
"""ProtoTorch Neural Network Module"""
|
||||
|
||||
from .activations import *
|
||||
from .wrappers import *
|
@ -1,4 +1,4 @@
|
||||
"""ProtoTorch activation functions."""
|
||||
"""ProtoTorch activations"""
|
||||
|
||||
import torch
|
||||
|
||||
@ -57,6 +57,10 @@ def get_activation(funcname):
|
||||
"""Deserialize the activation function."""
|
||||
if callable(funcname):
|
||||
return funcname
|
||||
if funcname in ACTIVATIONS:
|
||||
elif funcname in ACTIVATIONS:
|
||||
return ACTIVATIONS.get(funcname)
|
||||
raise NameError(f"Activation {funcname} was not found.")
|
||||
else:
|
||||
emsg = f"Unable to find matching function for `{funcname}` " \
|
||||
f"in `prototorch.nn.activations`. "
|
||||
helpmsg = f"Possible values are {list(ACTIVATIONS.keys())}."
|
||||
raise NameError(emsg + helpmsg)
|
@ -1,4 +1,4 @@
|
||||
"""ProtoTorch Wrappers."""
|
||||
"""ProtoTorch wrappers."""
|
||||
|
||||
import torch
|
||||
|
@ -0,0 +1,8 @@
|
||||
"""ProtoFlow utils module"""
|
||||
|
||||
from .colors import hex_to_rgb, rgb_to_hex
|
||||
from .utils import (
|
||||
mesh2d,
|
||||
parse_data_arg,
|
||||
parse_distribution,
|
||||
)
|
@ -1,46 +0,0 @@
|
||||
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
from matplotlib.animation import ArtistAnimation
|
||||
from matplotlib.artist import Artist
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
__version__ = "0.2.0"
|
||||
|
||||
|
||||
class Camera:
|
||||
"""Make animations easier."""
|
||||
def __init__(self, figure: Figure) -> None:
|
||||
"""Create camera from matplotlib figure."""
|
||||
self._figure = figure
|
||||
# need to keep track off artists for each axis
|
||||
self._offsets: Dict[str, Dict[int, int]] = {
|
||||
k: defaultdict(int)
|
||||
for k in
|
||||
["collections", "patches", "lines", "texts", "artists", "images"]
|
||||
}
|
||||
self._photos: List[List[Artist]] = []
|
||||
|
||||
def snap(self) -> List[Artist]:
|
||||
"""Capture current state of the figure."""
|
||||
frame_artists: List[Artist] = []
|
||||
for i, axis in enumerate(self._figure.axes):
|
||||
if axis.legend_ is not None:
|
||||
axis.add_artist(axis.legend_)
|
||||
for name in self._offsets:
|
||||
new_artists = getattr(axis, name)[self._offsets[name][i]:]
|
||||
frame_artists += new_artists
|
||||
self._offsets[name][i] += len(new_artists)
|
||||
self._photos.append(frame_artists)
|
||||
return frame_artists
|
||||
|
||||
def animate(self, *args, **kwargs) -> ArtistAnimation:
|
||||
"""Animate the snapshots taken.
|
||||
Uses matplotlib.animation.ArtistAnimation
|
||||
Returns
|
||||
-------
|
||||
ArtistAnimation
|
||||
"""
|
||||
return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
|
@ -1,78 +1,15 @@
|
||||
"""ProtoFlow color utilities."""
|
||||
|
||||
import matplotlib.lines as mlines
|
||||
from matplotlib import cm
|
||||
from matplotlib.colors import Normalize, to_hex, to_rgb
|
||||
"""ProtoFlow color utilities"""
|
||||
|
||||
|
||||
def color_scheme(n,
|
||||
cmap="viridis",
|
||||
form="hex",
|
||||
tikz=False,
|
||||
zero_indexed=False):
|
||||
"""Return *n* colors from the color scheme.
|
||||
|
||||
Arguments:
|
||||
n (int): number of colors to return
|
||||
|
||||
Keyword Arguments:
|
||||
cmap (str): Name of a matplotlib `colormap\
|
||||
<https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html>`_.
|
||||
form (str): Colorformat (supports "hex" and "rgb").
|
||||
tikz (bool): Output as `TikZ <https://github.com/pgf-tikz/pgf>`_
|
||||
command.
|
||||
zero_indexed (bool): Use zero indexing for output array.
|
||||
|
||||
Returns:
|
||||
(list): List of colors
|
||||
"""
|
||||
cmap = cm.get_cmap(cmap)
|
||||
colornorm = Normalize(vmin=1, vmax=n)
|
||||
hex_map = dict()
|
||||
rgb_map = dict()
|
||||
for cl in range(1, n + 1):
|
||||
if zero_indexed:
|
||||
hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
|
||||
rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
|
||||
else:
|
||||
hex_map[cl] = to_hex(cmap(colornorm(cl)))
|
||||
rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
|
||||
if tikz:
|
||||
for k, v in rgb_map.items():
|
||||
print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
|
||||
if form == "hex":
|
||||
return hex_map
|
||||
elif form == "rgb":
|
||||
return rgb_map
|
||||
else:
|
||||
return hex_map
|
||||
def hex_to_rgb(hex_values):
|
||||
for v in hex_values:
|
||||
v = v.lstrip('#')
|
||||
lv = len(v)
|
||||
c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
|
||||
yield c
|
||||
|
||||
|
||||
def get_legend_handles(labels, marker="dots", zero_indexed=False):
|
||||
"""Return matplotlib legend handles and colors."""
|
||||
handles = list()
|
||||
n = len(labels)
|
||||
colors = color_scheme(n,
|
||||
cmap="viridis",
|
||||
form="hex",
|
||||
zero_indexed=zero_indexed)
|
||||
for label, color in zip(labels, colors.values()):
|
||||
if marker == "dots":
|
||||
handle = mlines.Line2D(
|
||||
[],
|
||||
[],
|
||||
color="white",
|
||||
markerfacecolor=color,
|
||||
marker="o",
|
||||
markersize=10,
|
||||
markeredgecolor="k",
|
||||
label=label,
|
||||
)
|
||||
else:
|
||||
handle = mlines.Line2D([], [],
|
||||
color=color,
|
||||
marker="",
|
||||
markersize=15,
|
||||
label=label)
|
||||
handles.append(handle)
|
||||
return handles, colors
|
||||
def rgb_to_hex(rgb_values):
|
||||
for v in rgb_values:
|
||||
c = "%02x%02x%02x" % tuple(v)
|
||||
yield c
|
||||
|
104
prototorch/utils/utils.py
Normal file
104
prototorch/utils/utils.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""ProtoFlow utilities"""
|
||||
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||
if x is not None:
|
||||
x_shift = border * np.ptp(x[:, 0])
|
||||
y_shift = border * np.ptp(x[:, 1])
|
||||
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
|
||||
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
|
||||
else:
|
||||
x_min, x_max = -border, border
|
||||
y_min, y_max = -border, border
|
||||
xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
|
||||
np.linspace(y_min, y_max, resolution))
|
||||
mesh = np.c_[xx.ravel(), yy.ravel()]
|
||||
return mesh, xx, yy
|
||||
|
||||
|
||||
def distribution_from_list(list_dist: list[int],
|
||||
clabels: Iterable[int] = None):
|
||||
clabels = clabels or list(range(len(list_dist)))
|
||||
distribution = dict(zip(clabels, list_dist))
|
||||
return distribution
|
||||
|
||||
|
||||
def parse_distribution(user_distribution,
|
||||
clabels: Iterable[int] = None) -> dict[int, int]:
|
||||
"""Parse user-provided distribution.
|
||||
|
||||
Return a dictionary with integer keys that represent the class labels and
|
||||
values that denote the number of components/prototypes with that class
|
||||
label.
|
||||
|
||||
The argument `user_distribution` could be any one of a number of allowed
|
||||
formats. If it is a Python list, it is assumed that there are as many
|
||||
entries in this list as there are classes, and the value at each index of
|
||||
this list describes the number of prototypes for that particular class. So,
|
||||
[1, 1, 1] implies that we have three classes with one prototype per class.
|
||||
If it is a Python tuple, a shorthand of (num_classes, prototypes_per_class)
|
||||
is assumed. If it is a Python dictionary, the key-value pairs describe the
|
||||
class label and the number of prototypes for that class respectively. So,
|
||||
{0: 2, 1: 2, 2: 2} implies that we have three classes with labels {1, 2,
|
||||
3}, each equipped with two prototypes. If however, the dictionary contains
|
||||
the keys "num_classes" and "per_class", they are parsed to use their values
|
||||
as one might expect.
|
||||
|
||||
"""
|
||||
if isinstance(user_distribution, dict):
|
||||
if "num_classes" in user_distribution.keys():
|
||||
num_classes = int(user_distribution["num_classes"])
|
||||
per_class = int(user_distribution["per_class"])
|
||||
return distribution_from_list([per_class] * num_classes, clabels)
|
||||
else:
|
||||
return user_distribution
|
||||
elif isinstance(user_distribution, tuple):
|
||||
assert len(user_distribution) == 2
|
||||
num_classes, per_class = user_distribution
|
||||
num_classes, per_class = int(num_classes), int(per_class)
|
||||
return distribution_from_list([per_class] * num_classes, clabels)
|
||||
elif isinstance(user_distribution, list):
|
||||
return distribution_from_list(user_distribution, clabels)
|
||||
else:
|
||||
msg = f"`distribution` was not understood." \
|
||||
f"You have provided: {user_distribution}."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||
"""Return data and target as torch tensors."""
|
||||
if isinstance(data_arg, Dataset):
|
||||
if hasattr(data_arg, "__len__"):
|
||||
ds_size = len(data_arg) # type: ignore
|
||||
loader = DataLoader(data_arg, batch_size=ds_size)
|
||||
data, targets = next(iter(loader))
|
||||
else:
|
||||
emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
|
||||
raise TypeError(emsg)
|
||||
|
||||
elif isinstance(data_arg, DataLoader):
|
||||
data = torch.tensor([])
|
||||
targets = torch.tensor([])
|
||||
for x, y in data_arg:
|
||||
data = torch.cat([data, x])
|
||||
targets = torch.cat([targets, y])
|
||||
else:
|
||||
assert len(data_arg) == 2
|
||||
data, targets = data_arg
|
||||
if not isinstance(data, torch.Tensor):
|
||||
wmsg = f"Converting data to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
data = torch.Tensor(data)
|
||||
if not isinstance(targets, torch.LongTensor):
|
||||
wmsg = f"Converting targets to {torch.LongTensor}..."
|
||||
warnings.warn(wmsg)
|
||||
targets = torch.LongTensor(targets)
|
||||
return data, targets
|
15
setup.cfg
Normal file
15
setup.cfg
Normal file
@ -0,0 +1,15 @@
|
||||
[pylint]
|
||||
disable =
|
||||
too-many-arguments,
|
||||
too-few-public-methods,
|
||||
fixme,
|
||||
|
||||
[pycodestyle]
|
||||
max-line-length = 79
|
||||
|
||||
[isort]
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = True
|
||||
force_grid_wrap = 3
|
||||
use_parentheses = True
|
||||
line_length = 79
|
3
setup.py
3
setup.py
@ -61,8 +61,9 @@ setup(
|
||||
license="MIT",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"docs": DOCS,
|
||||
"datasets": DATASETS,
|
||||
"dev": DEV,
|
||||
"docs": DOCS,
|
||||
"examples": EXAMPLES,
|
||||
"tests": TESTS,
|
||||
"all": ALL,
|
||||
|
@ -1,26 +0,0 @@
|
||||
"""ProtoTorch components test suite."""
|
||||
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def test_labcomps_zeros_init():
|
||||
protos = torch.zeros(3, 2)
|
||||
c = pt.components.LabeledComponents(
|
||||
distribution=[1, 1, 1],
|
||||
initializer=pt.components.Zeros(2),
|
||||
)
|
||||
assert (c.components == protos).any() == True
|
||||
|
||||
|
||||
def test_labcomps_warmstart():
|
||||
protos = torch.randn(3, 2)
|
||||
plabels = torch.tensor([1, 2, 3])
|
||||
c = pt.components.LabeledComponents(
|
||||
distribution=[1, 1, 1],
|
||||
initializer=None,
|
||||
initialized_components=[protos, plabels],
|
||||
)
|
||||
assert (c.components == protos).any() == True
|
||||
assert (c.component_labels == plabels).any() == True
|
760
tests/test_core.py
Normal file
760
tests/test_core.py
Normal file
@ -0,0 +1,760 @@
|
||||
"""ProtoTorch core test suite"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
from prototorch.utils import parse_distribution
|
||||
|
||||
|
||||
# Utils
|
||||
def test_parse_distribution_dict_0():
|
||||
distribution = {"num_classes": 1, "per_class": 0}
|
||||
distribution = parse_distribution(distribution)
|
||||
assert distribution == {0: 0}
|
||||
|
||||
|
||||
def test_parse_distribution_dict_1():
|
||||
distribution = dict(num_classes=3, per_class=2)
|
||||
distribution = parse_distribution(distribution)
|
||||
assert distribution == {0: 2, 1: 2, 2: 2}
|
||||
|
||||
|
||||
def test_parse_distribution_dict_2():
|
||||
distribution = {0: 1, 2: 2, -1: 3}
|
||||
distribution = parse_distribution(distribution)
|
||||
assert distribution == {0: 1, 2: 2, -1: 3}
|
||||
|
||||
|
||||
def test_parse_distribution_tuple():
|
||||
distribution = (2, 3)
|
||||
distribution = parse_distribution(distribution)
|
||||
assert distribution == {0: 3, 1: 3}
|
||||
|
||||
|
||||
def test_parse_distribution_list():
|
||||
distribution = [1, 1, 0, 2]
|
||||
distribution = parse_distribution(distribution)
|
||||
assert distribution == {0: 1, 1: 1, 2: 0, 3: 2}
|
||||
|
||||
|
||||
def test_parse_distribution_custom_labels():
|
||||
distribution = [1, 1, 0, 2]
|
||||
clabels = [1, 2, 5, 3]
|
||||
distribution = parse_distribution(distribution, clabels)
|
||||
assert distribution == {1: 1, 2: 1, 5: 0, 3: 2}
|
||||
|
||||
|
||||
# Components initializers
|
||||
def test_literal_comp_generate():
|
||||
protos = torch.rand(4, 3, 5, 5)
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
components = c.generate([])
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_literal_comp_generate_from_list():
|
||||
protos = [[0, 1], [2, 3], [4, 5]]
|
||||
c = pt.initializers.LiteralCompInitializer(protos)
|
||||
with pytest.warns(UserWarning):
|
||||
components = c.generate([])
|
||||
assert torch.allclose(components, torch.Tensor(protos))
|
||||
|
||||
|
||||
def test_shape_aware_raises_error():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
|
||||
|
||||
|
||||
def test_data_aware_comp_generate():
|
||||
protos = torch.rand(4, 3, 5, 5)
|
||||
c = pt.initializers.DataAwareCompInitializer(protos)
|
||||
components = c.generate(num_components="IgnoreMe!")
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_class_aware_comp_generate():
|
||||
protos = torch.rand(4, 2, 3, 5, 5)
|
||||
plabels = torch.tensor([0, 0, 1, 1]).long()
|
||||
c = pt.initializers.ClassAwareCompInitializer([protos, plabels])
|
||||
components = c.generate(distribution=[])
|
||||
assert torch.allclose(components, protos)
|
||||
|
||||
|
||||
def test_zeros_comp_generate():
|
||||
shape = (3, 5, 5)
|
||||
c = pt.initializers.ZerosCompInitializer(shape)
|
||||
components = c.generate(num_components=4)
|
||||
assert torch.allclose(components, torch.zeros(4, 3, 5, 5))
|
||||
|
||||
|
||||
def test_ones_comp_generate():
|
||||
c = pt.initializers.OnesCompInitializer(2)
|
||||
components = c.generate(num_components=3)
|
||||
assert torch.allclose(components, torch.ones(3, 2))
|
||||
|
||||
|
||||
def test_fill_value_comp_generate():
|
||||
c = pt.initializers.FillValueCompInitializer(2, 0.0)
|
||||
components = c.generate(num_components=3)
|
||||
assert torch.allclose(components, torch.zeros(3, 2))
|
||||
|
||||
|
||||
def test_uniform_comp_generate_min_max_bound():
|
||||
c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0)
|
||||
components = c.generate(num_components=1024)
|
||||
assert components.min() >= -1.0
|
||||
assert components.max() <= 1.0
|
||||
|
||||
|
||||
def test_random_comp_generate_mean():
|
||||
c = pt.initializers.RandomNormalCompInitializer(2, -1.0)
|
||||
components = c.generate(num_components=1024)
|
||||
assert torch.allclose(components.mean(),
|
||||
torch.tensor(-1.0),
|
||||
rtol=1e-05,
|
||||
atol=1e-01)
|
||||
|
||||
|
||||
def test_comp_generate_0_components():
|
||||
c = pt.initializers.ZerosCompInitializer(2)
|
||||
_ = c.generate(num_components=0)
|
||||
|
||||
|
||||
def test_stratified_mean_comp_generate():
|
||||
# yapf: disable
|
||||
x = torch.Tensor(
|
||||
[[0, -1, -2],
|
||||
[10, 11, 12],
|
||||
[0, 0, 0],
|
||||
[2, 2, 2]])
|
||||
y = torch.LongTensor([0, 0, 1, 1])
|
||||
desired = torch.Tensor(
|
||||
[[5.0, 5.0, 5.0],
|
||||
[1.0, 1.0, 1.0]])
|
||||
# yapf: enable
|
||||
c = pt.initializers.StratifiedMeanCompInitializer(data=[x, y])
|
||||
actual = c.generate([1, 1])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_stratified_selection_comp_generate():
|
||||
# yapf: disable
|
||||
x = torch.Tensor(
|
||||
[[0, 0, 0],
|
||||
[1, 1, 1],
|
||||
[0, 0, 0],
|
||||
[1, 1, 1]])
|
||||
y = torch.LongTensor([0, 1, 0, 1])
|
||||
desired = torch.Tensor(
|
||||
[[0, 0, 0],
|
||||
[1, 1, 1]])
|
||||
# yapf: enable
|
||||
c = pt.initializers.StratifiedSelectionCompInitializer(data=[x, y])
|
||||
actual = c.generate([1, 1])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
# Labels initializers
|
||||
def test_literal_labels_init():
|
||||
l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
labels = l.generate([])
|
||||
assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
def test_labels_init_from_list():
|
||||
l = pt.initializers.LabelsInitializer()
|
||||
components = l.generate(distribution=[1, 1, 1])
|
||||
assert torch.allclose(components, torch.LongTensor([0, 1, 2]))
|
||||
|
||||
|
||||
def test_labels_init_from_tuple_legal():
|
||||
l = pt.initializers.LabelsInitializer()
|
||||
components = l.generate(distribution=(3, 1))
|
||||
assert torch.allclose(components, torch.LongTensor([0, 1, 2]))
|
||||
|
||||
|
||||
def test_labels_init_from_tuple_illegal():
|
||||
l = pt.initializers.LabelsInitializer()
|
||||
with pytest.raises(AssertionError):
|
||||
_ = l.generate(distribution=(1, 1, 1))
|
||||
|
||||
|
||||
def test_data_aware_labels_init():
|
||||
data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
|
||||
ds = pt.datasets.NumpyDataset(data, targets)
|
||||
l = pt.initializers.DataAwareLabelsInitializer(ds)
|
||||
labels = l.generate([])
|
||||
assert torch.allclose(labels, torch.LongTensor(targets))
|
||||
|
||||
|
||||
# Reasonings initializers
|
||||
def test_literal_reasonings_init():
|
||||
r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
|
||||
with pytest.warns(UserWarning):
|
||||
reasonings = r.generate([])
|
||||
assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
|
||||
|
||||
|
||||
def test_random_reasonings_init():
|
||||
r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
|
||||
reasonings = r.generate(distribution=[0, 1])
|
||||
assert torch.numel(reasonings) == 1 * 2 * 2
|
||||
assert reasonings.min() >= 0.2
|
||||
assert reasonings.max() <= 0.8
|
||||
|
||||
|
||||
def test_zeros_reasonings_init():
|
||||
r = pt.initializers.ZerosReasoningsInitializer()
|
||||
reasonings = r.generate(distribution=[0, 1])
|
||||
assert torch.allclose(reasonings, torch.zeros(1, 2, 2))
|
||||
|
||||
|
||||
def test_ones_reasonings_init():
|
||||
r = pt.initializers.ZerosReasoningsInitializer()
|
||||
reasonings = r.generate(distribution=[1, 2, 3])
|
||||
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
||||
|
||||
|
||||
def test_pure_positive_reasonings_init_one_per_class():
|
||||
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||
components_first=False)
|
||||
reasonings = r.generate(distribution=(4, 1))
|
||||
assert torch.allclose(reasonings[0], torch.eye(4))
|
||||
|
||||
|
||||
def test_pure_positive_reasonings_init_unrepresented_classes():
|
||||
r = pt.initializers.PurePositiveReasoningsInitializer()
|
||||
reasonings = r.generate(distribution=[9, 0, 0, 0])
|
||||
assert reasonings.shape[0] == 9
|
||||
assert reasonings.shape[1] == 4
|
||||
assert reasonings.shape[2] == 2
|
||||
|
||||
|
||||
def test_random_reasonings_init_channels_not_first():
|
||||
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
||||
reasonings = r.generate(distribution=[0, 0, 0, 1])
|
||||
assert reasonings.shape[0] == 2
|
||||
assert reasonings.shape[1] == 4
|
||||
assert reasonings.shape[2] == 1
|
||||
|
||||
|
||||
# Transform initializers
|
||||
def test_eye_transform_init_square():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
I = t.generate(3, 3)
|
||||
assert torch.allclose(I, torch.eye(3))
|
||||
|
||||
|
||||
def test_eye_transform_init_narrow():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(3, 2)
|
||||
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_eye_transform_init_wide():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(2, 3)
|
||||
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
# Transforms
|
||||
def test_linear_transform():
|
||||
l = pt.transforms.LinearTransform(2, 4)
|
||||
actual = l.weights
|
||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_zeros_init():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.ZerosLinearTransformInitializer(),
|
||||
)
|
||||
actual = l.weights
|
||||
desired = torch.zeros(2, 4)
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_out_dim_first():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.OLTI(out_dim_first=True),
|
||||
)
|
||||
assert l.weights.shape[0] == 4
|
||||
assert l.weights.shape[1] == 2
|
||||
|
||||
|
||||
# Components
|
||||
def test_components_no_initializer():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.components.Components(3, None)
|
||||
|
||||
|
||||
def test_components_no_num_components():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.components.Components(initializer=pt.initializers.OCI(2))
|
||||
|
||||
|
||||
def test_components_none_num_components():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.components.Components(None, initializer=pt.initializers.OCI(2))
|
||||
|
||||
|
||||
def test_components_no_args():
|
||||
with pytest.raises(TypeError):
|
||||
_ = pt.components.Components()
|
||||
|
||||
|
||||
def test_components_zeros_init():
|
||||
c = pt.components.Components(3, pt.initializers.ZCI(2))
|
||||
assert torch.allclose(c.components, torch.zeros(3, 2))
|
||||
|
||||
|
||||
def test_labeled_components_dict_init():
|
||||
c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2))
|
||||
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
|
||||
|
||||
|
||||
def test_labeled_components_list_init():
|
||||
c = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
|
||||
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
|
||||
|
||||
|
||||
def test_labeled_components_tuple_init():
|
||||
c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2))
|
||||
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||
assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1]))
|
||||
|
||||
|
||||
# Labels
|
||||
def test_standalone_labels_dict_init():
|
||||
l = pt.components.Labels({0: 3})
|
||||
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
|
||||
|
||||
|
||||
def test_standalone_labels_list_init():
|
||||
l = pt.components.Labels([3])
|
||||
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
|
||||
|
||||
|
||||
def test_standalone_labels_tuple_init():
|
||||
l = pt.components.Labels({0: 1, 1: 2})
|
||||
assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1]))
|
||||
|
||||
|
||||
# Losses
|
||||
def test_glvq_loss_int_labels():
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([0, 1])
|
||||
targets = torch.ones(100)
|
||||
batch_loss = pt.losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
assert loss_value == -100
|
||||
|
||||
|
||||
def test_glvq_loss_one_hot_labels():
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = pt.losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
assert loss_value == -100
|
||||
|
||||
|
||||
def test_glvq_loss_one_hot_unequal():
|
||||
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
|
||||
d = torch.stack(dlist, dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = pt.losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
assert loss_value == -100
|
||||
|
||||
|
||||
# Activations
|
||||
class TestActivations(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
||||
self.x = torch.randn(1024, 1)
|
||||
|
||||
def test_registry(self):
|
||||
self.assertIsNotNone(pt.nn.ACTIVATIONS)
|
||||
|
||||
def test_funcname_deserialization(self):
|
||||
for funcname in self.flist:
|
||||
f = pt.nn.get_activation(funcname)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
|
||||
def test_callable_deserialization(self):
|
||||
def dummy(x, **kwargs):
|
||||
return x
|
||||
|
||||
for f in [dummy, lambda x: x]:
|
||||
f = pt.nn.get_activation(f)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
self.assertEqual(1, f(1))
|
||||
|
||||
def test_unknown_deserialization(self):
|
||||
for funcname in ["blubb", "foobar"]:
|
||||
with self.assertRaises(NameError):
|
||||
_ = pt.nn.get_activation(funcname)
|
||||
|
||||
def test_identity(self):
|
||||
actual = pt.nn.identity(self.x)
|
||||
desired = self.x
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_sigmoid_beta1(self):
|
||||
actual = pt.nn.sigmoid_beta(self.x, beta=1.0)
|
||||
desired = torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_swish_beta1(self):
|
||||
actual = pt.nn.swish_beta(self.x, beta=1.0)
|
||||
desired = self.x * torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x
|
||||
|
||||
|
||||
# Competitions
|
||||
class TestCompetitions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_wtac(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_unequal_dist(self):
|
||||
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||
labels = torch.tensor([0, 1, 1])
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([0, 1])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_one_hot(self):
|
||||
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
competition_layer = pt.competitions.WTAC()
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([[0, 1], [1, 0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
competition_layer = pt.competitions.KNNC(k=1)
|
||||
actual = competition_layer(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
# Pooling
|
||||
class TestPooling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_stratified_min(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_trivial(self):
|
||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 1, 2])
|
||||
pooling_layer = pt.pooling.StratifiedMinPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
pooling_layer = pt.pooling.StratifiedMaxPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||
pooling_layer = pt.pooling.StratifiedMaxPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.LongTensor([0, 0, 1, 2])
|
||||
pooling_layer = pt.pooling.StratifiedSumPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
pooling_layer = pt.pooling.StratifiedSumPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_prod(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
pooling_layer = pt.pooling.StratifiedProdPooling()
|
||||
actual = pooling_layer(d, labels)
|
||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
# Distances
|
||||
class TestDistances(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.nx, self.mx = 32, 2048
|
||||
self.ny, self.my = 8, 2048
|
||||
self.x = torch.randn(self.nx, self.mx)
|
||||
self.y = torch.randn(self.ny, self.my)
|
||||
|
||||
def test_manhattan(self):
|
||||
actual = pt.distances.lpnorm_distance(self.x, self.y, p=1)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=1,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_euclidean(self):
|
||||
actual = pt.distances.euclidean_distance(self.x, self.y)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=3)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_squared_euclidean(self):
|
||||
actual = pt.distances.squared_euclidean_distance(self.x, self.y)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p0(self):
|
||||
actual = pt.distances.lpnorm_distance(self.x, self.y, p=0)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=0,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p2(self):
|
||||
actual = pt.distances.lpnorm_distance(self.x, self.y, p=2)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p3(self):
|
||||
actual = pt.distances.lpnorm_distance(self.x, self.y, p=3)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=3,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_pinf(self):
|
||||
actual = pt.distances.lpnorm_distance(self.x, self.y, p=float("inf"))
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=float("inf"),
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_omega_identity(self):
|
||||
omega = torch.eye(self.mx, self.my)
|
||||
actual = pt.distances.omega_distance(self.x, self.y, omega=omega)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lomega_identity(self):
|
||||
omega = torch.eye(self.mx, self.my)
|
||||
omegas = torch.stack([omega for _ in range(self.ny)], dim=0)
|
||||
actual = pt.distances.lomega_distance(self.x, self.y, omegas=omegas)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x, self.y
|
@ -1,32 +1,97 @@
|
||||
"""ProtoTorch datasets test suite."""
|
||||
"""ProtoTorch datasets test suite"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.datasets import abstract, tecator
|
||||
import prototorch as pt
|
||||
from prototorch.datasets.abstract import Dataset, ProtoDataset
|
||||
|
||||
|
||||
class TestAbstract(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ds = Dataset("./artifacts")
|
||||
|
||||
def test_getitem(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
abstract.Dataset("./artifacts")[0]
|
||||
_ = self.ds[0]
|
||||
|
||||
def test_len(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
len(abstract.Dataset("./artifacts"))
|
||||
_ = len(self.ds)
|
||||
|
||||
def tearDown(self):
|
||||
del self.ds
|
||||
|
||||
|
||||
class TestProtoDataset(unittest.TestCase):
|
||||
def test_getitem(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
abstract.ProtoDataset("./artifacts")[0]
|
||||
|
||||
def test_download(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
abstract.ProtoDataset("./artifacts").download()
|
||||
_ = ProtoDataset("./artifacts", download=True)
|
||||
|
||||
def test_exists(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = ProtoDataset("./artifacts", download=False)
|
||||
|
||||
|
||||
class TestNumpyDataset(unittest.TestCase):
|
||||
def test_list_init(self):
|
||||
ds = pt.datasets.NumpyDataset([1], [1])
|
||||
self.assertEqual(len(ds), 1)
|
||||
|
||||
def test_numpy_init(self):
|
||||
data = np.random.randn(3, 2)
|
||||
targets = np.array([0, 1, 2])
|
||||
ds = pt.datasets.NumpyDataset(data, targets)
|
||||
self.assertEqual(len(ds), 3)
|
||||
|
||||
|
||||
class TestSpiral(unittest.TestCase):
|
||||
def test_init(self):
|
||||
ds = pt.datasets.Spiral(num_samples=10)
|
||||
self.assertEqual(len(ds), 10)
|
||||
|
||||
|
||||
class TestIris(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.ds = pt.datasets.Iris()
|
||||
|
||||
def test_size(self):
|
||||
self.assertEqual(len(self.ds), 150)
|
||||
|
||||
def test_dims(self):
|
||||
self.assertEqual(self.ds.data.shape[1], 4)
|
||||
|
||||
def test_dims_selection(self):
|
||||
ds = pt.datasets.Iris(dims=[0, 1])
|
||||
self.assertEqual(ds.data.shape[1], 2)
|
||||
|
||||
|
||||
class TestBlobs(unittest.TestCase):
|
||||
def test_size(self):
|
||||
ds = pt.datasets.Blobs(num_samples=10)
|
||||
self.assertEqual(len(ds), 10)
|
||||
|
||||
|
||||
class TestRandom(unittest.TestCase):
|
||||
def test_size(self):
|
||||
ds = pt.datasets.Random(num_samples=10)
|
||||
self.assertEqual(len(ds), 10)
|
||||
|
||||
|
||||
class TestCircles(unittest.TestCase):
|
||||
def test_size(self):
|
||||
ds = pt.datasets.Circles(num_samples=10)
|
||||
self.assertEqual(len(ds), 10)
|
||||
|
||||
|
||||
class TestMoons(unittest.TestCase):
|
||||
def test_size(self):
|
||||
ds = pt.datasets.Moons(num_samples=10)
|
||||
self.assertEqual(len(ds), 10)
|
||||
|
||||
|
||||
class TestTecator(unittest.TestCase):
|
||||
@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
self._remove_artifacts()
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = tecator.Tecator(rootdir, download=False)
|
||||
_ = pt.datasets.Tecator(rootdir, download=False)
|
||||
|
||||
def test_download_caching(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
||||
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
||||
_ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
|
||||
_ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
|
||||
|
||||
def test_repr(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
||||
train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
|
||||
self.assertTrue("Split: Train" in train.__repr__())
|
||||
|
||||
def test_download_train(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
train = tecator.Tecator(root=rootdir,
|
||||
train = pt.datasets.Tecator(root=rootdir,
|
||||
train=True,
|
||||
download=True,
|
||||
verbose=False)
|
||||
train = tecator.Tecator(root=rootdir, download=True, verbose=False)
|
||||
train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
|
||||
x_train, y_train = train.data, train.targets
|
||||
self.assertEqual(x_train.shape[0], 144)
|
||||
self.assertEqual(y_train.shape[0], 144)
|
||||
@ -68,7 +133,7 @@ class TestTecator(unittest.TestCase):
|
||||
|
||||
def test_download_test(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||
x_test, y_test = test.data, test.targets
|
||||
self.assertEqual(x_test.shape[0], 71)
|
||||
self.assertEqual(y_test.shape[0], 71)
|
||||
@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase):
|
||||
|
||||
def test_class_to_idx(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||
_ = test.class_to_idx
|
||||
|
||||
def test_getitem(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||
x, y = test[0]
|
||||
self.assertEqual(x.shape[0], 100)
|
||||
self.assertIsInstance(y, int)
|
||||
|
||||
def test_loadable_with_dataloader(self):
|
||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
self._remove_artifacts()
|
||||
|
@ -1,581 +0,0 @@
|
||||
"""ProtoTorch functions test suite."""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions import (activations, competitions, distances,
|
||||
initializers, losses, pooling)
|
||||
|
||||
|
||||
class TestActivations(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
||||
self.x = torch.randn(1024, 1)
|
||||
|
||||
def test_registry(self):
|
||||
self.assertIsNotNone(activations.ACTIVATIONS)
|
||||
|
||||
def test_funcname_deserialization(self):
|
||||
for funcname in self.flist:
|
||||
f = activations.get_activation(funcname)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
|
||||
# def test_torch_script(self):
|
||||
# for funcname in self.flist:
|
||||
# f = activations.get_activation(funcname)
|
||||
# self.assertIsInstance(f, torch.jit.ScriptFunction)
|
||||
|
||||
def test_callable_deserialization(self):
|
||||
def dummy(x, **kwargs):
|
||||
return x
|
||||
|
||||
for f in [dummy, lambda x: x]:
|
||||
f = activations.get_activation(f)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
self.assertEqual(1, f(1))
|
||||
|
||||
def test_unknown_deserialization(self):
|
||||
for funcname in ["blubb", "foobar"]:
|
||||
with self.assertRaises(NameError):
|
||||
_ = activations.get_activation(funcname)
|
||||
|
||||
def test_identity(self):
|
||||
actual = activations.identity(self.x)
|
||||
desired = self.x
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_sigmoid_beta1(self):
|
||||
actual = activations.sigmoid_beta(self.x, beta=1.0)
|
||||
desired = torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_swish_beta1(self):
|
||||
actual = activations.swish_beta(self.x, beta=1.0)
|
||||
desired = self.x * torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x
|
||||
|
||||
|
||||
class TestCompetitions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_wtac(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_unequal_dist(self):
|
||||
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||
labels = torch.tensor([0, 1, 1])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([0, 1])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_one_hot(self):
|
||||
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([[0, 1], [1, 0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.knnc(d, labels, k=1)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestPooling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_stratified_min(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_trivial(self):
|
||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 1, 2])
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pooling.stratified_max_pooling(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||
actual = pooling.stratified_max_pooling(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.LongTensor([0, 0, 1, 2])
|
||||
actual = pooling.stratified_sum_pooling(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pooling.stratified_sum_pooling(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_prod(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pooling.stratified_prod_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestDistances(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.nx, self.mx = 32, 2048
|
||||
self.ny, self.my = 8, 2048
|
||||
self.x = torch.randn(self.nx, self.mx)
|
||||
self.y = torch.randn(self.ny, self.my)
|
||||
|
||||
def test_manhattan(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=1)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=1,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_euclidean(self):
|
||||
actual = distances.euclidean_distance(self.x, self.y)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=3)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_squared_euclidean(self):
|
||||
actual = distances.squared_euclidean_distance(self.x, self.y)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p0(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=0)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=0,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p2(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=2)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_p3(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=3)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=3,
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_pinf(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=float("inf"),
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=4)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_omega_identity(self):
|
||||
omega = torch.eye(self.mx, self.my)
|
||||
actual = distances.omega_distance(self.x, self.y, omega=omega)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lomega_identity(self):
|
||||
omega = torch.eye(self.mx, self.my)
|
||||
omegas = torch.stack([omega for _ in range(self.ny)], dim=0)
|
||||
actual = distances.lomega_distance(self.x, self.y, omegas=omegas)
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x, self.y
|
||||
|
||||
|
||||
class TestInitializers(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.flist = [
|
||||
"zeros",
|
||||
"ones",
|
||||
"rand",
|
||||
"randn",
|
||||
"stratified_mean",
|
||||
"stratified_random",
|
||||
]
|
||||
self.x = torch.tensor(
|
||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||
dtype=torch.float32)
|
||||
self.y = torch.tensor([0, 0, 1, 1])
|
||||
self.gen = torch.manual_seed(42)
|
||||
|
||||
def test_registry(self):
|
||||
self.assertIsNotNone(initializers.INITIALIZERS)
|
||||
|
||||
def test_funcname_deserialization(self):
|
||||
for funcname in self.flist:
|
||||
f = initializers.get_initializer(funcname)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
|
||||
def test_callable_deserialization(self):
|
||||
def dummy(x):
|
||||
return x
|
||||
|
||||
for f in [dummy, lambda x: x]:
|
||||
f = initializers.get_initializer(f)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
self.assertEqual(1, f(1))
|
||||
|
||||
def test_unknown_deserialization(self):
|
||||
for funcname in ["blubb", "foobar"]:
|
||||
with self.assertRaises(NameError):
|
||||
_ = initializers.get_initializer(funcname)
|
||||
|
||||
def test_zeros(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.zeros(self.x, self.y, pdist)
|
||||
desired = torch.zeros(2, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_ones(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.ones(self.x, self.y, pdist)
|
||||
desired = torch.ones(2, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_rand(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.rand(self.x, self.y, pdist)
|
||||
desired = torch.rand(2, 3, generator=torch.manual_seed(42))
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_randn(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.randn(self.x, self.y, pdist)
|
||||
desired = torch.randn(2, 3, generator=torch.manual_seed(42))
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_mean_equal1(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_equal1(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_mean_equal2(self):
|
||||
pdist = torch.tensor([2, 2])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_equal2(self):
|
||||
pdist = torch.tensor([2, 2])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_mean_unequal(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_unequal(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_mean_unequal_one_hot(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
y = torch.eye(2)[self.y]
|
||||
desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||
desired1,
|
||||
decimal=5)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual2,
|
||||
desired2,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_unequal_one_hot(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
y = torch.eye(2)[self.y]
|
||||
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
||||
desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||
desired1,
|
||||
decimal=5)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual2,
|
||||
desired2,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x, self.y, self.gen
|
||||
_ = torch.seed()
|
||||
|
||||
|
||||
class TestLosses(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_glvq_loss_int_labels(self):
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([0, 1])
|
||||
targets = torch.ones(100)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def test_glvq_loss_one_hot_labels(self):
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def test_glvq_loss_one_hot_unequal(self):
|
||||
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
|
||||
d = torch.stack(dlist, dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
47
tests/test_utils.py
Normal file
47
tests/test_utils.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""ProtoTorch utils test suite"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def test_mesh2d_without_input():
|
||||
mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10)
|
||||
assert mesh.shape[0] == 100
|
||||
assert mesh.shape[1] == 2
|
||||
assert xx.shape[0] == 10
|
||||
assert xx.shape[1] == 10
|
||||
assert yy.shape[0] == 10
|
||||
assert yy.shape[1] == 10
|
||||
assert np.min(xx) == -2.0
|
||||
assert np.max(xx) == 2.0
|
||||
assert np.min(yy) == -2.0
|
||||
assert np.max(yy) == 2.0
|
||||
|
||||
|
||||
def test_mesh2d_with_torch_input():
|
||||
x = 10 * torch.rand(5, 2)
|
||||
mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100)
|
||||
assert mesh.shape[0] == 100 * 100
|
||||
assert mesh.shape[1] == 2
|
||||
assert xx.shape[0] == 100
|
||||
assert xx.shape[1] == 100
|
||||
assert yy.shape[0] == 100
|
||||
assert yy.shape[1] == 100
|
||||
assert np.min(xx) == x[:, 0].min()
|
||||
assert np.max(xx) == x[:, 0].max()
|
||||
assert np.min(yy) == x[:, 1].min()
|
||||
assert np.max(yy) == x[:, 1].max()
|
||||
|
||||
|
||||
def test_hex_to_rgb():
|
||||
red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0]
|
||||
assert red_rgb[0] == 255
|
||||
assert red_rgb[1] == 0
|
||||
assert red_rgb[2] == 0
|
||||
|
||||
|
||||
def test_rgb_to_hex():
|
||||
blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0]
|
||||
assert blue_hex.lower() == "0000ff"
|
Loading…
Reference in New Issue
Block a user