diff --git a/.bumpversion.cfg b/.bumpversion.cfg index ef4285b..1646a5d 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -3,8 +3,8 @@ current_version = 0.5.0 commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+) -serialize = - {major}.{minor}.{patch} +serialize = {major}.{minor}.{patch} +message = bump: {current_version} → {new_version} [bumpversion:file:setup.py] diff --git a/.gitignore b/.gitignore index da6479c..c6b5985 100644 --- a/.gitignore +++ b/.gitignore @@ -129,14 +129,6 @@ dmypy.json # End of https://www.gitignore.io/api/python -# ProtoFlow -core -checkpoint -logs/ -saved_weights/ -scratch* - - # Created by https://www.gitignore.io/api/visualstudiocode # Edit at https://www.gitignore.io/?templates=visualstudiocode @@ -154,5 +146,13 @@ scratch* # End of https://www.gitignore.io/api/visualstudiocode .vscode/ +# Vim +*~ +*.swp +*.swo + +# Artifacts created by ProtoTorch reports artifacts +examples/_*.py +examples/_*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf37edc..a9a8838 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,19 +23,19 @@ repos: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.902' + rev: v0.902 hooks: - id: mypy files: prototorch additional_dependencies: [types-pkg_resources] - repo: https://github.com/pre-commit/mirrors-yapf - rev: 'v0.31.0' # Use the sha / tag you want to point at + rev: v0.31.0 hooks: - id: yapf - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.9.0 # Use the ref you want to point at + rev: v1.9.0 hooks: - id: python-use-type-annotations - id: python-no-log-warn @@ -47,8 +47,8 @@ repos: hooks: - id: pyupgrade -- repo: https://github.com/jorisroovers/gitlint - rev: "v0.15.1" +- repo: https://github.com/si-cim/gitlint + rev: v0.15.2-unofficial hooks: - id: gitlint args: [--contrib=CT1, --ignore=B6, --msg-filename] diff --git a/.remarkrc b/.remarkrc new file mode 100644 index 0000000..5f7b470 --- /dev/null +++ b/.remarkrc @@ -0,0 +1,7 @@ +{ + "plugins": [ + "remark-preset-lint-recommended", + ["remark-lint-list-item-indent", false], + ["no-emphasis-as-header", false] + ] +} diff --git a/.travis.yml b/.travis.yml index 77f8083..6320573 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ dist: bionic sudo: false language: python -python: 3.8 +python: 3.9 cache: directories: - "$HOME/.cache/pip" diff --git a/README.md b/README.md index 5e8aa19..86379ad 100644 --- a/README.md +++ b/README.md @@ -51,14 +51,20 @@ that link not work try . ## Contribution This repository contains definition for [git hooks](https://githooks.com). -[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch. -Please install the hooks by running +[Pre-commit](https://pre-commit.com) is automatically installed as development +dependency with prototorch or you can install it manually with `pip install +pre-commit`. + +Please install the hooks by running: ```bash pre-commit install pre-commit install --hook-type commit-msg ``` before creating the first commit. +The commit will fail if the commit message does not follow the specification +provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification). 
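For example, gitlint's `CT1` contrib rule checks for a conventional-commits title of the form `type: description` (optionally `type(scope): description`). The messages below are illustrative, not commits from this repository:

```bash
git commit -m "feat(components): add reasoning components"  # accepted
git commit -m "added reasoning components"                  # rejected: no type prefix
```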
+ ## Bibtex If you would like to cite the package, please use this: diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py new file mode 100644 index 0000000..2264a86 --- /dev/null +++ b/examples/cbc_iris.py @@ -0,0 +1,96 @@ +"""ProtoTorch CBC example using 2D Iris data.""" + +import torch +from matplotlib import pyplot as plt + +import prototorch as pt + + +class CBC(torch.nn.Module): + def __init__(self, data, **kwargs): + super().__init__(**kwargs) + self.components_layer = pt.components.ReasoningComponents( + distribution=[2, 1, 2], + components_initializer=pt.initializers.SSCI(data, noise=0.1), + reasonings_initializer=pt.initializers.PPRI(components_first=True), + ) + + def forward(self, x): + components, reasonings = self.components_layer() + sims = pt.similarities.euclidean_similarity(x, components) + probs = pt.competitions.cbcc(sims, reasonings) + return probs + + +class VisCBC2D(): + def __init__(self, model, data): + self.model = model + self.x_train, self.y_train = pt.utils.parse_data_arg(data) + self.title = "Components Visualization" + self.fig = plt.figure(self.title) + self.border = 0.1 + self.resolution = 100 + self.cmap = "viridis" + + def on_epoch_end(self): + x_train, y_train = self.x_train, self.y_train + _components = self.model.components_layer._components.detach() + ax = self.fig.gca() + ax.cla() + ax.set_title(self.title) + ax.axis("off") + ax.scatter( + x_train[:, 0], + x_train[:, 1], + c=y_train, + cmap=self.cmap, + edgecolor="k", + marker="o", + s=30, + ) + ax.scatter( + _components[:, 0], + _components[:, 1], + c="w", + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + x = torch.vstack((x_train, _components)) + mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution) + with torch.no_grad(): + y_pred = self.model( + torch.Tensor(mesh_input).type_as(_components)).argmax(1) + y_pred = y_pred.cpu().reshape(xx.shape) + ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + plt.pause(0.2) + + +if __name__ == "__main__": + train_ds = pt.datasets.Iris(dims=[0, 2]) + + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32) + + model = CBC(train_ds) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + criterion = pt.losses.MarginLoss(margin=0.1) + vis = VisCBC2D(model, train_ds) + + for epoch in range(200): + correct = 0.0 + for x, y in train_loader: + y_oh = torch.eye(3)[y] + y_pred = model(x) + loss = criterion(y_pred, y_oh).mean(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + correct += (y_pred.argmax(1) == y).float().sum(0) + + acc = 100 * correct / len(train_ds) + print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%") + vis.on_epoch_end() diff --git a/examples/new_components.py b/examples/new_components.py index 2cb5f9f..ff47622 100644 --- a/examples/new_components.py +++ b/examples/new_components.py @@ -1,39 +1,35 @@ """This example script shows the usage of the new components architecture. Serialization/deserialization also works as expected. 
+ """ -# DATASET import torch -from sklearn.datasets import load_iris -from sklearn.preprocessing import StandardScaler -scaler = StandardScaler() -x_train, y_train = load_iris(return_X_y=True) -x_train = x_train[:, [0, 2]] -scaler.fit(x_train) -x_train = scaler.transform(x_train) +import prototorch as pt -x_train = torch.Tensor(x_train) -y_train = torch.Tensor(y_train) -num_classes = len(torch.unique(y_train)) +ds = pt.datasets.Iris() -# CREATE NEW COMPONENTS -from prototorch.components import * -from prototorch.components.initializers import * - -unsupervised = Components(6, SelectionInitializer(x_train)) +unsupervised = pt.components.Components( + 6, + initializer=pt.initializers.ZCI(2), +) print(unsupervised()) -prototypes = LabeledComponents( - (3, 2), StratifiedSelectionInitializer(x_train, y_train)) +prototypes = pt.components.LabeledComponents( + (3, 2), + components_initializer=pt.initializers.SSCI(ds), +) print(prototypes()) -components = ReasoningComponents( - (3, 6), StratifiedSelectionInitializer(x_train, y_train)) -print(components()) +components = pt.components.ReasoningComponents( + (3, 2), + components_initializer=pt.initializers.SSCI(ds), + reasonings_initializer=pt.initializers.PPRI(), +) +print(prototypes()) -# TEST SERIALIZATION +# Test Serialization import io save = io.BytesIO() @@ -41,25 +37,20 @@ torch.save(unsupervised, save) save.seek(0) serialized_unsupervised = torch.load(save) -assert torch.all(unsupervised.components == serialized_unsupervised.components - ), "Serialization of Components failed." +assert torch.all(unsupervised.components == serialized_unsupervised.components) save = io.BytesIO() torch.save(prototypes, save) save.seek(0) serialized_prototypes = torch.load(save) -assert torch.all(prototypes.components == serialized_prototypes.components - ), "Serialization of Components failed." -assert torch.all(prototypes.component_labels == serialized_prototypes. - component_labels), "Serialization of Components failed." +assert torch.all(prototypes.components == serialized_prototypes.components) +assert torch.all(prototypes.labels == serialized_prototypes.labels) save = io.BytesIO() torch.save(components, save) save.seek(0) serialized_components = torch.load(save) -assert torch.all(components.components == serialized_components.components - ), "Serialization of Components failed." -assert torch.all(components.reasonings == serialized_components.reasonings - ), "Serialization of Components failed." +assert torch.all(components.components == serialized_components.components) +assert torch.all(components.reasonings == serialized_components.reasonings) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index 79750a0..9ebb609 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -1,21 +1,41 @@ -"""ProtoTorch package.""" +"""ProtoTorch package""" import pkgutil from typing import List import pkg_resources -from . import components, datasets, functions, modules, utils -from .datasets import * +from . 
import ( + datasets, + nn, + utils, +) +from .core import ( + competitions, + components, + distances, + initializers, + losses, + pooling, + similarities, + transforms, +) # Core Setup __version__ = "0.5.0" __all_core__ = [ - "datasets", - "functions", - "modules", + "competitions", "components", + "core", + "datasets", + "distances", + "initializers", + "losses", + "nn", + "pooling", + "similarities", + "transforms", "utils", ] diff --git a/prototorch/components/__init__.py b/prototorch/components/__init__.py deleted file mode 100644 index 07dd543..0000000 --- a/prototorch/components/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from prototorch.components.components import * -from prototorch.components.initializers import * diff --git a/prototorch/components/components.py b/prototorch/components/components.py deleted file mode 100644 index c7565d5..0000000 --- a/prototorch/components/components.py +++ /dev/null @@ -1,270 +0,0 @@ -"""ProtoTorch components modules.""" - -import warnings - -import torch -from torch.nn.parameter import Parameter - -from prototorch.components.initializers import (ClassAwareInitializer, - ComponentsInitializer, - EqualLabelsInitializer, - UnequalLabelsInitializer, - ZeroReasoningsInitializer) - -from .initializers import parse_data_arg - - -def get_labels_object(distribution): - if isinstance(distribution, dict): - if "num_classes" in distribution.keys(): - labels = EqualLabelsInitializer( - distribution["num_classes"], - distribution["prototypes_per_class"]) - else: - clabels = list(distribution.keys()) - dist = list(distribution.values()) - labels = UnequalLabelsInitializer(dist, clabels) - elif isinstance(distribution, tuple): - num_classes, prototypes_per_class = distribution - labels = EqualLabelsInitializer(num_classes, prototypes_per_class) - elif isinstance(distribution, list): - labels = UnequalLabelsInitializer(distribution) - else: - msg = f"`distribution` not understood." \ - f"You have provided: {distribution=}." - raise ValueError(msg) - return labels - - -def _precheck_initializer(initializer): - if not isinstance(initializer, ComponentsInitializer): - emsg = f"`initializer` has to be some subtype of " \ - f"{ComponentsInitializer}. " \ - f"You have provided: {initializer=} instead." - raise TypeError(emsg) - - -class LinearMapping(torch.nn.Module): - """LinearMapping is a learnable Mapping Matrix.""" - def __init__(self, - mapping_shape=None, - initializer=None, - *, - initialized_linearmapping=None): - super().__init__() - - # Ignore all initialization settings if initialized_components is given. 
- if initialized_linearmapping is not None: - self._register_mapping(initialized_linearmapping) - if num_components is not None or initializer is not None: - wmsg = "Arguments ignored while initializing Components" - warnings.warn(wmsg) - else: - self._initialize_mapping(mapping_shape, initializer) - - @property - def mapping_shape(self): - return self._omega.shape - - def _register_mapping(self, components): - self.register_parameter("_omega", Parameter(components)) - - def _initialize_mapping(self, mapping_shape, initializer): - _precheck_initializer(initializer) - _mapping = initializer.generate(mapping_shape) - self._register_mapping(_mapping) - - @property - def mapping(self): - """Tensor containing the component tensors.""" - return self._omega.detach() - - def forward(self): - return self._omega - - -class Components(torch.nn.Module): - """Components is a set of learnable Tensors.""" - def __init__(self, - num_components=None, - initializer=None, - *, - initialized_components=None): - super().__init__() - - # Ignore all initialization settings if initialized_components is given. - if initialized_components is not None: - self._register_components(initialized_components) - if num_components is not None or initializer is not None: - wmsg = "Arguments ignored while initializing Components" - warnings.warn(wmsg) - else: - self._initialize_components(num_components, initializer) - - @property - def num_components(self): - return len(self._components) - - def _register_components(self, components): - self.register_parameter("_components", Parameter(components)) - - def _initialize_components(self, num_components, initializer): - _precheck_initializer(initializer) - _components = initializer.generate(num_components) - self._register_components(_components) - - def add_components(self, - num=1, - initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - _components = torch.cat([self._components, initialized_components]) - else: - _precheck_initializer(initializer) - _new = initializer.generate(num) - _components = torch.cat([self._components, _new]) - self._register_components(_components) - - def remove_components(self, indices=None): - mask = torch.ones(self.num_components, dtype=torch.bool) - mask[indices] = False - _components = self._components[mask] - self._register_components(_components) - return mask - - @property - def components(self): - """Tensor containing the component tensors.""" - return self._components.detach() - - def forward(self): - return self._components - - def extra_repr(self): - return f"(components): (shape: {tuple(self._components.shape)})" - - -class LabeledComponents(Components): - """LabeledComponents generate a set of components and a set of labels. - - Every Component has a label assigned. 
- """ - def __init__(self, - distribution=None, - initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - components, component_labels = parse_data_arg( - initialized_components) - super().__init__(initialized_components=components) - self._register_labels(component_labels) - else: - labels = get_labels_object(distribution) - self.initial_distribution = labels.distribution - _labels = labels.generate() - super().__init__(len(_labels), initializer=initializer) - self._register_labels(_labels) - - def _register_labels(self, labels): - self.register_buffer("_labels", labels) - - @property - def distribution(self): - clabels, counts = torch.unique(self._labels, - sorted=True, - return_counts=True) - return dict(zip(clabels.tolist(), counts.tolist())) - - def _initialize_components(self, num_components, initializer): - if isinstance(initializer, ClassAwareInitializer): - _precheck_initializer(initializer) - _components = initializer.generate(num_components, - self.initial_distribution) - self._register_components(_components) - else: - super()._initialize_components(num_components, initializer) - - def add_components(self, distribution, initializer): - _precheck_initializer(initializer) - - # Labels - labels = get_labels_object(distribution) - new_labels = labels.generate() - _labels = torch.cat([self._labels, new_labels]) - self._register_labels(_labels) - - # Components - if isinstance(initializer, ClassAwareInitializer): - _new = initializer.generate(len(new_labels), distribution) - else: - _new = initializer.generate(len(new_labels)) - _components = torch.cat([self._components, _new]) - self._register_components(_components) - - def remove_components(self, indices=None): - # Components - mask = super().remove_components(indices) - - # Labels - _labels = self._labels[mask] - self._register_labels(_labels) - - @property - def component_labels(self): - """Tensor containing the component tensors.""" - return self._labels.detach() - - def forward(self): - return super().forward(), self._labels - - -class ReasoningComponents(Components): - r"""ReasoningComponents generate a set of components and a set of reasoning matrices. - - Every Component has a reasoning matrix assigned. - - A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The - first element is called positive reasoning :math:`p`, the second negative - reasoning :math:`n`. A components can reason in favour (positive) of a - class, against (negative) a class or not at all (neutral). - - It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 - \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a - three element probability distribution. - - """ - def __init__(self, - reasonings=None, - initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - components, reasonings = initialized_components - - super().__init__(initialized_components=components) - self.register_parameter("_reasonings", reasonings) - else: - self._initialize_reasonings(reasonings) - super().__init__(len(self._reasonings), initializer=initializer) - - def _initialize_reasonings(self, reasonings): - if isinstance(reasonings, tuple): - num_classes, num_components = reasonings - reasonings = ZeroReasoningsInitializer(num_classes, num_components) - - _reasonings = reasonings.generate() - self.register_parameter("_reasonings", _reasonings) - - @property - def reasonings(self): - """Returns Reasoning Matrix. 
- - Dimension NxCx2 - - """ - return self._reasonings.detach() - - def forward(self): - return super().forward(), self._reasonings diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py deleted file mode 100644 index 2e15b05..0000000 --- a/prototorch/components/initializers.py +++ /dev/null @@ -1,233 +0,0 @@ -"""ProtoTroch Initializers.""" -import warnings -from collections.abc import Iterable -from itertools import chain - -import torch -from torch.utils.data import DataLoader, Dataset - - -def parse_data_arg(data_arg): - if isinstance(data_arg, Dataset): - data_arg = DataLoader(data_arg, batch_size=len(data_arg)) - - if isinstance(data_arg, DataLoader): - data = torch.tensor([]) - targets = torch.tensor([]) - for x, y in data_arg: - data = torch.cat([data, x]) - targets = torch.cat([targets, y]) - else: - data, targets = data_arg - if not isinstance(data, torch.Tensor): - wmsg = f"Converting data to {torch.Tensor}." - warnings.warn(wmsg) - data = torch.Tensor(data) - if not isinstance(targets, torch.Tensor): - wmsg = f"Converting targets to {torch.Tensor}." - warnings.warn(wmsg) - targets = torch.Tensor(targets) - return data, targets - - -def get_subinitializers(data, targets, clabels, subinit_type): - initializers = dict() - for clabel in clabels: - class_data = data[targets == clabel] - class_initializer = subinit_type(class_data) - initializers[clabel] = (class_initializer) - return initializers - - -# Components -class ComponentsInitializer(object): - def generate(self, number_of_components): - raise NotImplementedError("Subclasses should implement this!") - - -class DimensionAwareInitializer(ComponentsInitializer): - def __init__(self, dims): - super().__init__() - if isinstance(dims, Iterable): - self.components_dims = tuple(dims) - else: - self.components_dims = (dims, ) - - -class OnesInitializer(DimensionAwareInitializer): - def __init__(self, dims, scale=1.0): - super().__init__(dims) - self.scale = scale - - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims) * self.scale - - -class ZerosInitializer(DimensionAwareInitializer): - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.zeros(gen_dims) - - -class UniformInitializer(DimensionAwareInitializer): - def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0): - super().__init__(dims) - self.minimum = minimum - self.maximum = maximum - self.scale = scale - - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims).uniform_(self.minimum, - self.maximum) * self.scale - - -class DataAwareInitializer(ComponentsInitializer): - def __init__(self, data, transform=torch.nn.Identity()): - super().__init__() - self.data = data - self.transform = transform - - def __del__(self): - del self.data - - -class SelectionInitializer(DataAwareInitializer): - def generate(self, length): - indices = torch.LongTensor(length).random_(0, len(self.data)) - return self.transform(self.data[indices]) - - -class MeanInitializer(DataAwareInitializer): - def generate(self, length): - mean = torch.mean(self.data, dim=0) - repeat_dim = [length] + [1] * len(mean.shape) - return self.transform(mean.repeat(repeat_dim)) - - -class ClassAwareInitializer(DataAwareInitializer): - def __init__(self, data, transform=torch.nn.Identity()): - data, targets = parse_data_arg(data) - super().__init__(data, transform) - self.targets = targets - self.clabels = 
torch.unique(self.targets).int().tolist() - self.num_classes = len(self.clabels) - - def _get_samples_from_initializer(self, length, dist): - if not dist: - per_class = length // self.num_classes - dist = dict(zip(self.clabels, self.num_classes * [per_class])) - if isinstance(dist, list): - dist = dict(zip(self.clabels, dist)) - samples = [self.initializers[k].generate(n) for k, n in dist.items()] - out = torch.vstack(samples) - with torch.no_grad(): - out = self.transform(out) - return out - - def __del__(self): - del self.data - del self.targets - - -class StratifiedMeanInitializer(ClassAwareInitializer): - def __init__(self, data, **kwargs): - super().__init__(data, **kwargs) - self.initializers = get_subinitializers(self.data, self.targets, - self.clabels, MeanInitializer) - - def generate(self, length, dist): - samples = self._get_samples_from_initializer(length, dist) - return samples - - -class StratifiedSelectionInitializer(ClassAwareInitializer): - def __init__(self, data, noise=None, **kwargs): - super().__init__(data, **kwargs) - self.noise = noise - self.initializers = get_subinitializers(self.data, self.targets, - self.clabels, - SelectionInitializer) - - def add_noise_v1(self, x): - return x + self.noise - - def add_noise_v2(self, x): - """Shifts some dimensions of the data randomly.""" - n1 = torch.rand_like(x) - n2 = torch.rand_like(x) - mask = torch.bernoulli(n1) - torch.bernoulli(n2) - return x + (self.noise * mask) - - def generate(self, length, dist): - samples = self._get_samples_from_initializer(length, dist) - if self.noise is not None: - samples = self.add_noise_v1(samples) - return samples - - -# Omega matrix -class PcaInitializer(DataAwareInitializer): - def generate(self, shape): - (input_dim, latent_dim) = shape - (_, eigVal, eigVec) = torch.pca_lowrank(self.data, q=latent_dim) - return eigVec - - -# Labels -class LabelsInitializer: - def generate(self): - raise NotImplementedError("Subclasses should implement this!") - - -class UnequalLabelsInitializer(LabelsInitializer): - def __init__(self, dist, clabels=None): - self.dist = dist - self.clabels = clabels or range(len(self.dist)) - - @property - def distribution(self): - return self.dist - - def generate(self): - targets = list( - chain(*[[i] * n for i, n in zip(self.clabels, self.dist)])) - return torch.LongTensor(targets) - - -class EqualLabelsInitializer(LabelsInitializer): - def __init__(self, classes, per_class): - self.classes = classes - self.per_class = per_class - - @property - def distribution(self): - return self.classes * [self.per_class] - - def generate(self): - return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() - - -# Reasonings -class ReasoningsInitializer: - def generate(self, length): - raise NotImplementedError("Subclasses should implement this!") - - -class ZeroReasoningsInitializer(ReasoningsInitializer): - def __init__(self, classes, length): - self.classes = classes - self.length = length - - def generate(self): - return torch.zeros((self.length, self.classes, 2)) - - -# Aliases -SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer -SMI = StratifiedMeanInitializer -Random = RandomInitializer = UniformInitializer -Zeros = ZerosInitializer -Ones = OnesInitializer -PCA = PcaInitializer diff --git a/prototorch/core/__init__.py b/prototorch/core/__init__.py new file mode 100644 index 0000000..e5961c1 --- /dev/null +++ b/prototorch/core/__init__.py @@ -0,0 +1,10 @@ +"""ProtoTorch core""" + +from .competitions import * +from .components import * +from 
.distances import * +from .initializers import * +from .losses import * +from .pooling import * +from .similarities import * +from .transforms import * diff --git a/prototorch/core/competitions.py b/prototorch/core/competitions.py new file mode 100644 index 0000000..8ac72e7 --- /dev/null +++ b/prototorch/core/competitions.py @@ -0,0 +1,89 @@ +"""ProtoTorch competitions""" + +import torch + + +def wtac(distances: torch.Tensor, labels: torch.LongTensor): + """Winner-Takes-All-Competition. + + Returns the labels corresponding to the winners. + + """ + winning_indices = torch.min(distances, dim=1).indices + winning_labels = labels[winning_indices].squeeze() + return winning_labels + + +def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1): + """K-Nearest-Neighbors-Competition. + + Returns the labels corresponding to the winners. + + """ + winning_indices = torch.topk(-distances, k=k, dim=1).indices + winning_labels = torch.mode(labels[winning_indices], dim=1).values + return winning_labels + + +def cbcc(detections: torch.Tensor, reasonings: torch.Tensor): + """Classification-By-Components Competition. + + Returns probability distributions over the classes. + + `detections` must be of shape [batch_size, num_components]. + `reasonings` must be of shape [num_components, num_classes, 2]. + + """ + A, B = reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - A) * B + numerator = (detections @ (pk - nk).T) + nk.sum(1) + probs = numerator / (pk + nk).sum(1) + return probs + + +class WTAC(torch.nn.Module): + """Winner-Takes-All-Competition Layer. + + Thin wrapper over the `wtac` function. + + """ + def forward(self, distances, labels): + return wtac(distances, labels) + + +class LTAC(torch.nn.Module): + """Loser-Takes-All-Competition Layer. + + Thin wrapper over the `wtac` function. + + """ + def forward(self, probs, labels): + return wtac(-1.0 * probs, labels) + + +class KNNC(torch.nn.Module): + """K-Nearest-Neighbors-Competition. + + Thin wrapper over the `knnc` function. + + """ + def __init__(self, k=1, **kwargs): + super().__init__(**kwargs) + self.k = k + + def forward(self, distances, labels): + return knnc(distances, labels, k=self.k) + + def extra_repr(self): + return f"k: {self.k}" + + +class CBCC(torch.nn.Module): + """Classification-By-Components Competition. + + Thin wrapper over the `cbcc` function. + + """ + def forward(self, detections, reasonings): + return cbcc(detections, reasonings) diff --git a/prototorch/core/components.py b/prototorch/core/components.py new file mode 100644 index 0000000..e9c6433 --- /dev/null +++ b/prototorch/core/components.py @@ -0,0 +1,370 @@ +"""ProtoTorch components""" + +import inspect +from typing import Union + +import torch +from torch.nn.parameter import Parameter + +from ..utils import parse_distribution +from .initializers import ( + AbstractClassAwareCompInitializer, + AbstractComponentsInitializer, + AbstractLabelsInitializer, + AbstractReasoningsInitializer, + LabelsInitializer, + PurePositiveReasoningsInitializer, + RandomReasoningsInitializer, +) + + +def validate_initializer(initializer, instanceof): + """Check if the initializer is valid.""" + if not isinstance(initializer, instanceof): + emsg = f"`initializer` has to be an instance " \ + f"of some subtype of {instanceof}. " \ + f"You have provided: {initializer} instead. " + helpmsg = "" + if inspect.isclass(initializer): + helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \ + f"with the brackets instead of just {initializer.__name__}?" 
+ raise TypeError(emsg + helpmsg) + return True + + +def gencat(ins, attr, init, *iargs, **ikwargs): + """Generate new items and concatenate with existing items.""" + new_items = init.generate(*iargs, **ikwargs) + if hasattr(ins, attr): + items = torch.cat([getattr(ins, attr), new_items]) + else: + items = new_items + return items, new_items + + +def removeind(ins, attr, indices): + """Remove items at specified indices.""" + mask = torch.ones(len(ins), dtype=torch.bool) + mask[indices] = False + items = getattr(ins, attr)[mask] + return items, mask + + +def get_cikwargs(init, distribution): + """Return appropriate key-word arguments for a component initializer.""" + if isinstance(init, AbstractClassAwareCompInitializer): + cikwargs = dict(distribution=distribution) + else: + distribution = parse_distribution(distribution) + num_components = sum(distribution.values()) + cikwargs = dict(num_components=num_components) + return cikwargs + + +class AbstractComponents(torch.nn.Module): + """Abstract class for all components modules.""" + @property + def num_components(self): + """Current number of components.""" + return len(self._components) + + @property + def components(self): + """Detached Tensor containing the components.""" + return self._components.detach().cpu() + + def _register_components(self, components): + self.register_parameter("_components", Parameter(components)) + + def extra_repr(self): + return f"components: (shape: {tuple(self._components.shape)})" + + def __len__(self): + return self.num_components + + +class Components(AbstractComponents): + """A set of adaptable Tensors.""" + def __init__(self, num_components: int, + initializer: AbstractComponentsInitializer): + super().__init__() + self.add_components(num_components, initializer) + + def add_components(self, num_components: int, + initializer: AbstractComponentsInitializer): + """Generate and add new components.""" + assert validate_initializer(initializer, AbstractComponentsInitializer) + _components, new_components = gencat(self, "_components", initializer, + num_components) + self._register_components(_components) + return new_components + + def remove_components(self, indices): + """Remove components at specified indices.""" + _components, mask = removeind(self, "_components", indices) + self._register_components(_components) + return mask + + def forward(self): + """Simply return the components parameter Tensor.""" + return self._components + + +class AbstractLabels(torch.nn.Module): + """Abstract class for all labels modules.""" + @property + def labels(self): + return self._labels.cpu() + + @property + def num_labels(self): + return len(self._labels) + + @property + def unique_labels(self): + return torch.unique(self._labels) + + @property + def num_unique(self): + return len(self.unique_labels) + + @property + def distribution(self): + unique, counts = torch.unique(self._labels, + sorted=True, + return_counts=True) + return dict(zip(unique.tolist(), counts.tolist())) + + def _register_labels(self, labels): + self.register_buffer("_labels", labels) + + def extra_repr(self): + r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}" + if len(self.distribution) < 11: # avoid lengthy representations + d = self.distribution + unique, counts = list(d.keys()), list(d.values()) + r += f", unique: {unique}, counts: {counts}" + return r + + def __len__(self): + return self.num_labels + + +class Labels(AbstractLabels): + """A set of standalone labels.""" + def __init__(self, + distribution: Union[dict, list, 
tuple], + initializer: AbstractLabelsInitializer = LabelsInitializer()): + super().__init__() + self.add_labels(distribution, initializer) + + def add_labels( + self, + distribution: Union[dict, tuple, list], + initializer: AbstractLabelsInitializer = LabelsInitializer()): + """Generate and add new labels.""" + assert validate_initializer(initializer, AbstractLabelsInitializer) + _labels, new_labels = gencat(self, "_labels", initializer, + distribution) + self._register_labels(_labels) + return new_labels + + def remove_labels(self, indices): + """Remove labels at specified indices.""" + _labels, mask = removeind(self, "_labels", indices) + self._register_labels(_labels) + return mask + + def forward(self): + """Simply return the labels.""" + return self._labels + + +class LabeledComponents(AbstractComponents): + """A set of adaptable components and corresponding unadaptable labels.""" + def __init__( + self, + distribution: Union[dict, list, tuple], + components_initializer: AbstractComponentsInitializer, + labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): + super().__init__() + self.add_components(distribution, components_initializer, + labels_initializer) + + @property + def distribution(self): + unique, counts = torch.unique(self._labels, + sorted=True, + return_counts=True) + return dict(zip(unique.tolist(), counts.tolist())) + + @property + def num_classes(self): + return len(self.distribution.keys()) + + @property + def labels(self): + """Tensor containing the component labels.""" + return self._labels.cpu() + + def _register_labels(self, labels): + self.register_buffer("_labels", labels) + + def add_components( + self, + distribution, + components_initializer, + labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): + """Generate and add new components and labels.""" + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(labels_initializer, + AbstractLabelsInitializer) + cikwargs = get_cikwargs(components_initializer, distribution) + _components, new_components = gencat(self, "_components", + components_initializer, + **cikwargs) + _labels, new_labels = gencat(self, "_labels", labels_initializer, + distribution) + self._register_components(_components) + self._register_labels(_labels) + return new_components, new_labels + + def remove_components(self, indices): + """Remove components and labels at specified indices.""" + _components, mask = removeind(self, "_components", indices) + _labels, mask = removeind(self, "_labels", indices) + self._register_components(_components) + self._register_labels(_labels) + return mask + + def forward(self): + """Simply return the components parameter Tensor and labels.""" + return self._components, self._labels + + +class Reasonings(torch.nn.Module): + """A set of standalone reasoning matrices. + + The `reasonings` tensor is of shape [num_components, num_classes, 2]. 
+ + """ + def __init__( + self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer()): + super().__init__() + + @property + def num_classes(self): + return self._reasonings.shape[1] + + @property + def reasonings(self): + """Tensor containing the reasoning matrices.""" + return self._reasonings.detach().cpu() + + def _register_reasonings(self, reasonings): + self.register_buffer("_reasonings", reasonings) + + def add_reasonings( + self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer()): + """Generate and add new reasonings.""" + assert validate_initializer(initializer, AbstractReasoningsInitializer) + _reasonings, new_reasonings = gencat(self, "_reasonings", initializer, + distribution) + self._register_reasonings(_reasonings) + return new_reasonings + + def remove_reasonings(self, indices): + """Remove reasonings at specified indices.""" + _reasonings, mask = removeind(self, "_reasonings", indices) + self._register_reasonings(_reasonings) + return mask + + def forward(self): + """Simply return the reasonings.""" + return self._reasonings + + +class ReasoningComponents(AbstractComponents): + r"""A set of components and a corresponding adapatable reasoning matrices. + + Every component has its own reasoning matrix. + + A reasoning matrix is an Nx2 matrix, where N is the number of classes. The + first element is called positive reasoning :math:`p`, the second negative + reasoning :math:`n`. A components can reason in favour (positive) of a + class, against (negative) a class or not at all (neutral). + + It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 + \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a + three element probability distribution. 
+ + """ + def __init__( + self, + distribution: Union[dict, list, tuple], + components_initializer: AbstractComponentsInitializer, + reasonings_initializer: + AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()): + super().__init__() + self.add_components(distribution, components_initializer, + reasonings_initializer) + + @property + def num_classes(self): + return self._reasonings.shape[1] + + @property + def reasonings(self): + """Tensor containing the reasoning matrices.""" + return self._reasonings.detach().cpu() + + @property + def reasoning_matrices(self): + """Reasoning matrices for each class.""" + with torch.no_grad(): + A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - pk) * B + ik = 1 - pk - nk + matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0) + return matrices.cpu() + + def _register_reasonings(self, reasonings): + self.register_parameter("_reasonings", Parameter(reasonings)) + + def add_components(self, distribution, components_initializer, + reasonings_initializer: AbstractReasoningsInitializer): + """Generate and add new components and reasonings.""" + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(reasonings_initializer, + AbstractReasoningsInitializer) + cikwargs = get_cikwargs(components_initializer, distribution) + _components, new_components = gencat(self, "_components", + components_initializer, + **cikwargs) + _reasonings, new_reasonings = gencat(self, "_reasonings", + reasonings_initializer, + distribution) + self._register_components(_components) + self._register_reasonings(_reasonings) + return new_components, new_reasonings + + def remove_components(self, indices): + """Remove components and reasonings at specified indices.""" + _components, mask = removeind(self, "_components", indices) + _reasonings, mask = removeind(self, "_reasonings", indices) + self._register_components(_components) + self._register_reasonings(_reasonings) + return mask + + def forward(self): + """Simply return the components and reasonings.""" + return self._components, self._reasonings diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py new file mode 100644 index 0000000..c19a8dc --- /dev/null +++ b/prototorch/core/distances.py @@ -0,0 +1,98 @@ +"""ProtoTorch distances""" + +import torch + + +def squared_euclidean_distance(x, y): + r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. + + Compute :math:`{\langle \bm x - \bm y \rangle}_2` + + **Alias:** + ``prototorch.functions.distances.sed`` + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + expanded_x = x.unsqueeze(dim=1) + batchwise_difference = y - expanded_x + differences_raised = torch.pow(batchwise_difference, 2) + distances = torch.sum(differences_raised, axis=2) + return distances + + +def euclidean_distance(x, y): + r"""Compute the Euclidean distance between :math:`x` and :math:`y`. 
+ + Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}` + + :returns: Distance Tensor of shape :math:`X \times Y` + :rtype: `torch.tensor` + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + distances_raised = squared_euclidean_distance(x, y) + distances = torch.sqrt(distances_raised) + return distances + + +def euclidean_distance_v2(x, y): + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + diff = y - x.unsqueeze(1) + pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() + # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the + # batch diagonal. See: + # https://pytorch.org/docs/stable/generated/torch.diagonal.html + distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1) + # print(f"{diff.shape=}") # (nx, ny, ndim) + # print(f"{pairwise_distances.shape=}") # (nx, ny, ny) + # print(f"{distances.shape=}") # (nx, ny) + return distances + + +def lpnorm_distance(x, y, p): + r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`. + Also known as Minkowski distance. + + Compute :math:`{\| \bm x - \bm y \|}_p`. + + Calls ``torch.cdist`` + + :param p: p parameter of the lp norm + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + distances = torch.cdist(x, y, p=p) + return distances + + +def omega_distance(x, y, omega): + r"""Omega distance. + + Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p` + + :param `torch.tensor` omega: Two dimensional matrix + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + projected_x = x @ omega + projected_y = y @ omega + distances = squared_euclidean_distance(projected_x, projected_y) + return distances + + +def lomega_distance(x, y, omegas): + r"""Localized Omega distance. + + Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p` + + :param `torch.tensor` omegas: Three dimensional matrix + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + projected_x = x @ omegas + projected_y = torch.diagonal(y @ omegas).T + expanded_y = torch.unsqueeze(projected_y, dim=1) + batchwise_difference = expanded_y - projected_x + differences_squared = batchwise_difference**2 + distances = torch.sum(differences_squared, dim=2) + distances = distances.permute(1, 0) + return distances + + +# Aliases +sed = squared_euclidean_distance diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py new file mode 100644 index 0000000..fc5e83f --- /dev/null +++ b/prototorch/core/initializers.py @@ -0,0 +1,494 @@ +"""ProtoTorch code initializers""" + +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import ( + Callable, + Type, + Union, +) + +import torch + +from ..utils import parse_data_arg, parse_distribution + + +# Components +class AbstractComponentsInitializer(ABC): + """Abstract class for all components initializers.""" + ... + + +class LiteralCompInitializer(AbstractComponentsInitializer): + """'Generate' the provided components. + + Use this to 'generate' pre-initialized components elsewhere. + + """ + def __init__(self, components): + self.components = components + + def generate(self, num_components: int = 0): + """Ignore `num_components` and simply return `self.components`.""" + if not isinstance(self.components, torch.Tensor): + wmsg = f"Converting components to {torch.Tensor}..." 
+ warnings.warn(wmsg) + self.components = torch.Tensor(self.components) + return self.components + + +class ShapeAwareCompInitializer(AbstractComponentsInitializer): + """Abstract class for all dimension-aware components initializers.""" + def __init__(self, shape: Union[Iterable, int]): + if isinstance(shape, Iterable): + self.component_shape = tuple(shape) + else: + self.component_shape = (shape, ) + + @abstractmethod + def generate(self, num_components: int): + ... + + +class ZerosCompInitializer(ShapeAwareCompInitializer): + """Generate zeros corresponding to the components shape.""" + def generate(self, num_components: int): + components = torch.zeros((num_components, ) + self.component_shape) + return components + + +class OnesCompInitializer(ShapeAwareCompInitializer): + """Generate ones corresponding to the components shape.""" + def generate(self, num_components: int): + components = torch.ones((num_components, ) + self.component_shape) + return components + + +class FillValueCompInitializer(OnesCompInitializer): + """Generate components with the provided `fill_value`.""" + def __init__(self, shape, fill_value: float = 1.0): + super().__init__(shape) + self.fill_value = fill_value + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = ones.fill_(self.fill_value) + return components + + +class UniformCompInitializer(OnesCompInitializer): + """Generate components by sampling from a continuous uniform distribution.""" + def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0): + super().__init__(shape) + self.minimum = minimum + self.maximum = maximum + self.scale = scale + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = self.scale * ones.uniform_(self.minimum, self.maximum) + return components + + +class RandomNormalCompInitializer(OnesCompInitializer): + """Generate components by sampling from a standard normal distribution.""" + def __init__(self, shape, shift=0.0, scale=1.0): + super().__init__(shape) + self.shift = shift + self.scale = scale + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = self.scale * (torch.randn_like(ones) + self.shift) + return components + + +class AbstractDataAwareCompInitializer(AbstractComponentsInitializer): + """Abstract class for all data-aware components initializers. + + Components generated by data-aware components initializers inherit the shape + of the provided data. + + `data` has to be a torch tensor. + + """ + def __init__(self, + data: torch.Tensor, + noise: float = 0.0, + transform: Callable = torch.nn.Identity()): + self.data = data + self.noise = noise + self.transform = transform + + def generate_end_hook(self, samples): + drift = torch.rand_like(samples) * self.noise + components = self.transform(samples + drift) + return components + + @abstractmethod + def generate(self, num_components: int): + ... + return self.generate_end_hook(...) 
+ + def __del__(self): + del self.data + + +class DataAwareCompInitializer(AbstractDataAwareCompInitializer): + """'Generate' the components from the provided data.""" + def generate(self, num_components: int = 0): + """Ignore `num_components` and simply return transformed `self.data`.""" + components = self.generate_end_hook(self.data) + return components + + +class SelectionCompInitializer(AbstractDataAwareCompInitializer): + """Generate components by uniformly sampling from the provided data.""" + def generate(self, num_components: int): + indices = torch.LongTensor(num_components).random_(0, len(self.data)) + samples = self.data[indices] + components = self.generate_end_hook(samples) + return components + + +class MeanCompInitializer(AbstractDataAwareCompInitializer): + """Generate components by computing the mean of the provided data.""" + def generate(self, num_components: int): + mean = self.data.mean(dim=0) + repeat_dim = [num_components] + [1] * len(mean.shape) + samples = mean.repeat(repeat_dim) + components = self.generate_end_hook(samples) + return components + + +class AbstractClassAwareCompInitializer(AbstractComponentsInitializer): + """Abstract class for all class-aware components initializers. + + Components generated by class-aware components initializers inherit the shape + of the provided data. + + `data` could be a torch Dataset or DataLoader or a list/tuple of data and + target tensors. + + """ + def __init__(self, + data, + noise: float = 0.0, + transform: Callable = torch.nn.Identity()): + self.data, self.targets = parse_data_arg(data) + self.noise = noise + self.transform = transform + self.clabels = torch.unique(self.targets).int().tolist() + self.num_classes = len(self.clabels) + + def generate_end_hook(self, samples): + drift = torch.rand_like(samples) * self.noise + components = self.transform(samples + drift) + return components + + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + return self.generate_end_hook(...) + + def __del__(self): + del self.data + del self.targets + + +class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): + """'Generate' components from provided data and requested distribution.""" + def generate(self, distribution: Union[dict, list, tuple]): + """Ignore `distribution` and simply return transformed `self.data`.""" + components = self.generate_end_hook(self.data) + return components + + +class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): + """Abstract class for all stratified components initializers.""" + @property + @abstractmethod + def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]: + ... 
+ + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + components = torch.tensor([]) + for k, v in distribution.items(): + stratified_data = self.data[self.targets == k] + initializer = self.subinit_type( + stratified_data, + noise=self.noise, + transform=self.transform, + ) + samples = initializer.generate(num_components=v) + components = torch.cat([components, samples]) + return components + + +class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): + """Generate components using stratified sampling from the provided data.""" + @property + def subinit_type(self): + return SelectionCompInitializer + + +class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): + """Generate components at stratified means of the provided data.""" + @property + def subinit_type(self): + return MeanCompInitializer + + +# Labels +class AbstractLabelsInitializer(ABC): + """Abstract class for all labels initializers.""" + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + + +class LiteralLabelsInitializer(AbstractLabelsInitializer): + """'Generate' the provided labels. + + Use this to 'generate' pre-initialized labels elsewhere. + + """ + def __init__(self, labels): + self.labels = labels + + def generate(self, distribution: Union[dict, list, tuple]): + """Ignore `distribution` and simply return `self.labels`. + + Convert to long tensor, if necessary. + """ + labels = self.labels + if not isinstance(labels, torch.LongTensor): + wmsg = f"Converting labels to {torch.LongTensor}..." + warnings.warn(wmsg) + labels = torch.LongTensor(labels) + return labels + + +class DataAwareLabelsInitializer(AbstractLabelsInitializer): + """'Generate' the labels from a torch Dataset.""" + def __init__(self, data): + self.data, self.targets = parse_data_arg(data) + + def generate(self, distribution: Union[dict, list, tuple]): + """Ignore `num_components` and simply return `self.targets`.""" + return self.targets + + +class LabelsInitializer(AbstractLabelsInitializer): + """Generate labels from `distribution`.""" + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + labels_list = [] + for k, v in distribution.items(): + labels_list.extend([k] * v) + labels = torch.LongTensor(labels_list) + return labels + + +class OneHotLabelsInitializer(LabelsInitializer): + """Generate one-hot-encoded labels from `distribution`.""" + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + num_classes = len(distribution.keys()) + # this breaks if class labels are not [0,...,nclasses] + labels = torch.eye(num_classes)[super().generate(distribution)] + return labels + + +# Reasonings +class AbstractReasoningsInitializer(ABC): + """Abstract class for all reasonings initializers.""" + def __init__(self, components_first: bool = True): + self.components_first = components_first + + def compute_shape(self, distribution): + distribution = parse_distribution(distribution) + num_components = sum(distribution.values()) + num_classes = len(distribution.keys()) + return (num_components, num_classes, 2) + + def generate_end_hook(self, reasonings): + if not self.components_first: + reasonings = reasonings.permute(2, 1, 0) + return reasonings + + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + return self.generate_end_hook(...) 
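A brief usage sketch of the reasonings initializers defined in this file (the distribution `{0: 2, 1: 1}`, i.e. two components for class 0 and one for class 1, is an invented example; the `PPRI` and `RRI` aliases are declared at the bottom of this file):

```python
import prototorch as pt

# Three components over two classes -> reasonings of shape [3, 2, 2],
# i.e. [num_components, num_classes, 2] with components_first=True.
ppri = pt.initializers.PPRI(components_first=True)
print(ppri.generate({0: 2, 1: 1}).shape)  # torch.Size([3, 2, 2])

# Random reasonings are drawn uniformly from [0.4, 0.6] by default.
rri = pt.initializers.RRI()
print(rri.generate({0: 2, 1: 1}).shape)  # torch.Size([3, 2, 2])
```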
+
+
+class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
+    """'Generate' the provided reasonings.
+
+    Use this to 'generate' pre-initialized reasonings elsewhere.
+
+    """
+    def __init__(self, reasonings, **kwargs):
+        super().__init__(**kwargs)
+        self.reasonings = reasonings
+
+    def generate(self, distribution: Union[dict, list, tuple]):
+        """Ignore `distribution` and simply return `self.reasonings`."""
+        reasonings = self.reasonings
+        if not isinstance(reasonings, torch.Tensor):
+            wmsg = f"Converting reasonings to {torch.Tensor}..."
+            warnings.warn(wmsg)
+            reasonings = torch.Tensor(reasonings)
+        reasonings = self.generate_end_hook(reasonings)
+        return reasonings
+
+
+class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
+    """Reasonings are all initialized with zeros."""
+    def generate(self, distribution: Union[dict, list, tuple]):
+        shape = self.compute_shape(distribution)
+        reasonings = torch.zeros(*shape)
+        reasonings = self.generate_end_hook(reasonings)
+        return reasonings
+
+
+class OnesReasoningsInitializer(AbstractReasoningsInitializer):
+    """Reasonings are all initialized with ones."""
+    def generate(self, distribution: Union[dict, list, tuple]):
+        shape = self.compute_shape(distribution)
+        reasonings = torch.ones(*shape)
+        reasonings = self.generate_end_hook(reasonings)
+        return reasonings
+
+
+class RandomReasoningsInitializer(AbstractReasoningsInitializer):
+    """Reasonings are randomly initialized."""
+    def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
+        super().__init__(**kwargs)
+        self.minimum = minimum
+        self.maximum = maximum
+
+    def generate(self, distribution: Union[dict, list, tuple]):
+        shape = self.compute_shape(distribution)
+        reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
+        reasonings = self.generate_end_hook(reasonings)
+        return reasonings
+
+
+class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
+    """Each component reasons positively for exactly one class."""
+    def generate(self, distribution: Union[dict, list, tuple]):
+        num_components, num_classes, _ = self.compute_shape(distribution)
+        A = OneHotLabelsInitializer().generate(distribution)
+        B = torch.zeros(num_components, num_classes)
+        reasonings = torch.stack([A, B], dim=-1)
+        reasonings = self.generate_end_hook(reasonings)
+        return reasonings
+
+
+# Transforms
+class AbstractTransformInitializer(ABC):
+    """Abstract class for all transform initializers."""
+    ...
+
+
+class AbstractLinearTransformInitializer(AbstractTransformInitializer):
+    """Abstract class for all linear transform initializers."""
+    def __init__(self, out_dim_first: bool = False):
+        self.out_dim_first = out_dim_first
+
+    def generate_end_hook(self, weights):
+        if self.out_dim_first:
+            weights = weights.permute(1, 0)
+        return weights
+
+    @abstractmethod
+    def generate(self, in_dim: int, out_dim: int):
+        ...
+        return self.generate_end_hook(...)
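Similarly, a sketch of the linear-transform initializer API (the 4-by-2 shape is invented for illustration; `Eye` is the alias for the `EyeTransformInitializer` defined just below, declared at the bottom of this file):

```python
import prototorch as pt

# The largest identity matrix that fits into a 4x2 weight matrix; with the
# default out_dim_first=False the generated shape is (in_dim, out_dim).
omega = pt.initializers.Eye().generate(in_dim=4, out_dim=2)
print(omega)
# tensor([[1., 0.],
#         [0., 1.],
#         [0., 0.],
#         [0., 0.]])
```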
+ + +class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with zeros.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with ones.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.ones(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class EyeTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with the largest possible identity matrix.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + I = torch.eye(min(in_dim, out_dim)) + weights[:I.shape[0], :I.shape[1]] = I + return self.generate_end_hook(weights) + + +class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): + """Abstract class for all data-aware linear transform initializers.""" + def __init__(self, + data: torch.Tensor, + noise: float = 0.0, + transform: Callable = torch.nn.Identity()): + self.data = data + self.noise = noise + self.transform = transform + + def generate_end_hook(self, weights: torch.Tensor): + drift = torch.rand_like(weights) * self.noise + weights = self.transform(weights + drift) + if self.out_dim_first: + weights = weights.permute(1, 0) + return weights + + +class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): + """Initialize a matrix with Eigenvectors from the data.""" + @abstractmethod + def generate(self, in_dim: int, out_dim: int): + _, _, weights = torch.pca_lowrank(self.data, q=out_dim) + return self.generate_end_hook(weights) + + +# Aliases - Components +CACI = ClassAwareCompInitializer +DACI = DataAwareCompInitializer +FVCI = FillValueCompInitializer +LCI = LiteralCompInitializer +MCI = MeanCompInitializer +OCI = OnesCompInitializer +RNCI = RandomNormalCompInitializer +SCI = SelectionCompInitializer +SMCI = StratifiedMeanCompInitializer +SSCI = StratifiedSelectionCompInitializer +UCI = UniformCompInitializer +ZCI = ZerosCompInitializer + +# Aliases - Labels +DLI = DataAwareLabelsInitializer +LI = LabelsInitializer +LLI = LiteralLabelsInitializer +OHLI = OneHotLabelsInitializer + +# Aliases - Reasonings +LRI = LiteralReasoningsInitializer +ORI = OnesReasoningsInitializer +PPRI = PurePositiveReasoningsInitializer +RRI = RandomReasoningsInitializer +ZRI = ZerosReasoningsInitializer + +# Aliases - Transforms +Eye = EyeTransformInitializer +OLTI = OnesLinearTransformInitializer +ZLTI = ZerosLinearTransformInitializer +PCALTI = PCALinearTransformInitializer diff --git a/prototorch/functions/losses.py b/prototorch/core/losses.py similarity index 58% rename from prototorch/functions/losses.py rename to prototorch/core/losses.py index 249882a..1a32103 100644 --- a/prototorch/functions/losses.py +++ b/prototorch/core/losses.py @@ -1,8 +1,11 @@ -"""ProtoTorch loss functions.""" +"""ProtoTorch losses""" import torch +from ..nn.activations import get_activation + +# Helpers def _get_matcher(targets, labels): """Returns a boolean tensor.""" matcher = torch.eq(targets.unsqueeze(dim=1), labels) @@ -28,6 +31,7 @@ def _get_dp_dm(distances, targets, plabels, with_indices=False): return dp.values, dm.values +# GLVQ def glvq_loss(distances, target_labels, prototype_labels): """GLVQ loss function with support for one-hot labels.""" dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) @@ -92,3 +96,76 @@ def rslvq_loss(probabilities, targets, 
diff --git a/prototorch/functions/losses.py b/prototorch/core/losses.py
similarity index 58%
rename from prototorch/functions/losses.py
rename to prototorch/core/losses.py
index 249882a..1a32103 100644
--- a/prototorch/functions/losses.py
+++ b/prototorch/core/losses.py
@@ -1,8 +1,11 @@
-"""ProtoTorch loss functions."""
+"""ProtoTorch losses"""
 
 import torch
 
+from ..nn.activations import get_activation
+
 
+# Helpers
 def _get_matcher(targets, labels):
     """Returns a boolean tensor."""
     matcher = torch.eq(targets.unsqueeze(dim=1), labels)
@@ -28,6 +31,7 @@ def _get_dp_dm(distances, targets, plabels, with_indices=False):
     return dp.values, dm.values
 
 
+# GLVQ
 def glvq_loss(distances, target_labels, prototype_labels):
     """GLVQ loss function with support for one-hot labels."""
     dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
@@ -92,3 +96,76 @@ def rslvq_loss(probabilities, targets, prototype_labels):
     likelihood = correct / whole
     log_likelihood = torch.log(likelihood)
     return -1.0 * log_likelihood
+
+
+def margin_loss(y_pred, y_true, margin=0.3):
+    """Compute the margin loss."""
+    dp = torch.sum(y_true * y_pred, dim=-1)
+    dm = torch.max(y_pred - y_true, dim=-1).values
+    return torch.nn.functional.relu(dm - dp + margin)
+
+
+class GLVQLoss(torch.nn.Module):
+    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
+        super().__init__(**kwargs)
+        self.margin = margin
+        self.squashing = get_activation(squashing)
+        self.beta = torch.tensor(beta)
+
+    def forward(self, outputs, targets):
+        distances, plabels = outputs
+        mu = glvq_loss(distances, targets, prototype_labels=plabels)
+        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
+        return torch.sum(batch_loss, dim=0)
+
+
+class MarginLoss(torch.nn.modules.loss._Loss):
+    def __init__(self,
+                 margin=0.3,
+                 size_average=None,
+                 reduce=None,
+                 reduction="mean"):
+        super().__init__(size_average, reduce, reduction)
+        self.margin = margin
+
+    def forward(self, y_pred, y_true):
+        return margin_loss(y_pred, y_true, self.margin)
+
+
+class NeuralGasEnergy(torch.nn.Module):
+    def __init__(self, lm, **kwargs):
+        super().__init__(**kwargs)
+        self.lm = lm
+
+    def forward(self, d):
+        order = torch.argsort(d, dim=1)
+        ranks = torch.argsort(order, dim=1)
+        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
+
+        return cost, order
+
+    def extra_repr(self):
+        return f"lambda: {self.lm}"
+
+    @staticmethod
+    def _nghood_fn(rankings, lm):
+        return torch.exp(-rankings / lm)
+
+
+class GrowingNeuralGasEnergy(NeuralGasEnergy):
+    def __init__(self, topology_layer, **kwargs):
+        super().__init__(**kwargs)
+        self.topology_layer = topology_layer
+
+    @staticmethod
+    def _nghood_fn(rankings, topology):
+        winner = rankings[:, 0]
+
+        weights = torch.zeros_like(rankings, dtype=torch.float)
+        weights[torch.arange(rankings.shape[0]), winner] = 1.0
+
+        neighbours = topology.get_neighbours(winner)
+
+        weights[neighbours] = 0.1
+
+        return weights
diff --git a/prototorch/functions/pooling.py b/prototorch/core/pooling.py
similarity index 76%
rename from prototorch/functions/pooling.py
rename to prototorch/core/pooling.py
index 6dd427e..fab143f 100644
--- a/prototorch/functions/pooling.py
+++ b/prototorch/core/pooling.py
@@ -1,4 +1,4 @@
-"""ProtoTorch pooling functions."""
+"""ProtoTorch pooling"""
 
 from typing import Callable
 
@@ -78,3 +78,27 @@ def stratified_prod_pooling(values: torch.Tensor,
         fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
         fill_value=1.0)
     return winning_values
+
+
+class StratifiedSumPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_sum_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_sum_pooling(values, labels)
+
+
+class StratifiedProdPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_prod_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_prod_pooling(values, labels)
+
+
+class StratifiedMinPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_min_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_min_pooling(values, labels)
+
+
+class StratifiedMaxPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_max_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_max_pooling(values, labels)
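To make the stratified pooling semantics concrete, a small editorial sketch (not part of the patch; the expected values mirror the test suite further below). Values are pooled per class label, giving one output column per class:

```python
import torch
import prototorch as pt

values = torch.tensor([[1.0, 0.0, 2.0, 3.0],
                       [9.0, 8.0, 0.0, 1.0]])
labels = torch.tensor([0, 0, 1, 2])  # columns 0 and 1 share class 0

# One output column per class; class 0 pools over the first two columns.
pooled = pt.pooling.StratifiedMinPooling()(values, labels)
print(pooled)  # tensor([[0., 2., 3.], [8., 0., 1.]])
```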
diff --git a/prototorch/functions/similarities.py b/prototorch/core/similarities.py
similarity index 55%
rename from prototorch/functions/similarities.py
rename to prototorch/core/similarities.py
index cc91c78..9929610 100644
--- a/prototorch/functions/similarities.py
+++ b/prototorch/core/similarities.py
@@ -1,7 +1,19 @@
-"""ProtoTorch similarity functions."""
+"""ProtoTorch similarities."""
 
 import torch
 
+from .distances import euclidean_distance
+
+
+def gaussian(x, variance=1.0):
+    return torch.exp(-(x * x) / (2 * variance))
+
+
+def euclidean_similarity(x, y, variance=1.0):
+    distances = euclidean_distance(x, y)
+    similarities = gaussian(distances, variance)
+    return similarities
+
 
 def cosine_similarity(x, y):
     """Compute the cosine similarity between :math:`x` and :math:`y`.
@@ -9,6 +21,7 @@ def cosine_similarity(x, y):
     Expected dimension of x is 2.
     Expected dimension of y is 2.
     """
+    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
     norm_x = x.pow(2).sum(1).sqrt()
     norm_y = y.pow(2).sum(1).sqrt()
     norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py
new file mode 100644
index 0000000..efac17c
--- /dev/null
+++ b/prototorch/core/transforms.py
@@ -0,0 +1,43 @@
+"""ProtoTorch transforms"""
+
+import torch
+from torch.nn.parameter import Parameter
+
+from .initializers import (
+    AbstractLinearTransformInitializer,
+    EyeTransformInitializer,
+)
+
+
+class LinearTransform(torch.nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        initializer:
+        AbstractLinearTransformInitializer = EyeTransformInitializer()):
+        super().__init__()
+        self.set_weights(in_dim, out_dim, initializer)
+
+    @property
+    def weights(self):
+        return self._weights.detach().cpu()
+
+    def _register_weights(self, weights):
+        self.register_parameter("_weights", Parameter(weights))
+
+    def set_weights(
+        self,
+        in_dim: int,
+        out_dim: int,
+        initializer:
+        AbstractLinearTransformInitializer = EyeTransformInitializer()):
+        weights = initializer.generate(in_dim, out_dim)
+        self._register_weights(weights)
+
+    def forward(self, x):
+        return x @ self._weights
+
+
+# Aliases
+Omega = LinearTransform
diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py
index 1d61061..096fc6f 100644
--- a/prototorch/datasets/__init__.py
+++ b/prototorch/datasets/__init__.py
@@ -1,6 +1,12 @@
-"""ProtoTorch datasets."""
+"""ProtoTorch datasets"""
 
 from .abstract import NumpyDataset
-from .sklearn import Blobs, Circles, Iris, Moons, Random
+from .sklearn import (
+    Blobs,
+    Circles,
+    Iris,
+    Moons,
+    Random,
+)
 from .spiral import Spiral
 from .tecator import Tecator
diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py
index e941c95..dac8f8c 100644
--- a/prototorch/datasets/abstract.py
+++ b/prototorch/datasets/abstract.py
@@ -1,10 +1,11 @@
-"""ProtoTorch abstract dataset classes.
+"""ProtoTorch abstract dataset classes
 
-Based on `torchvision.VisionDataset` and `torchvision.MNIST`
+Based on `torchvision.VisionDataset` and `torchvision.MNIST`.
For the original code, see: https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + """ import os @@ -12,15 +13,6 @@ import os import torch -class NumpyDataset(torch.utils.data.TensorDataset): - """Create a PyTorch TensorDataset from NumPy arrays.""" - def __init__(self, data, targets): - self.data = torch.Tensor(data) - self.targets = torch.LongTensor(targets) - tensors = [self.data, self.targets] - super().__init__(*tensors) - - class Dataset(torch.utils.data.Dataset): """Abstract dataset class to be inherited.""" @@ -44,7 +36,7 @@ class ProtoDataset(Dataset): training_file = "training.pt" test_file = "test.pt" - def __init__(self, root, train=True, download=True, verbose=True): + def __init__(self, root="", train=True, download=True, verbose=True): super().__init__(root) self.train = train # training set or test set self.verbose = verbose @@ -96,3 +88,12 @@ class ProtoDataset(Dataset): def _download(self): raise NotImplementedError + + +class NumpyDataset(torch.utils.data.TensorDataset): + """Create a PyTorch TensorDataset from NumPy arrays.""" + def __init__(self, data, targets): + self.data = torch.Tensor(data) + self.targets = torch.LongTensor(targets) + tensors = [self.data, self.targets] + super().__init__(*tensors) diff --git a/prototorch/functions/__init__.py b/prototorch/functions/__init__.py deleted file mode 100644 index 9b3b993..0000000 --- a/prototorch/functions/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""ProtoTorch functions.""" - -from .activations import identity, sigmoid_beta, swish_beta -from .competitions import knnc, wtac -from .pooling import * diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py deleted file mode 100644 index 326d510..0000000 --- a/prototorch/functions/competitions.py +++ /dev/null @@ -1,28 +0,0 @@ -"""ProtoTorch competition functions.""" - -import torch - - -def wtac(distances: torch.Tensor, - labels: torch.LongTensor) -> (torch.LongTensor): - """Winner-Takes-All-Competition. - - Returns the labels corresponding to the winners. - - """ - winning_indices = torch.min(distances, dim=1).indices - winning_labels = labels[winning_indices].squeeze() - return winning_labels - - -def knnc(distances: torch.Tensor, - labels: torch.LongTensor, - k: int = 1) -> (torch.LongTensor): - """K-Nearest-Neighbors-Competition. - - Returns the labels corresponding to the winners. - - """ - winning_indices = torch.topk(-distances, k=k, dim=1).indices - winning_labels = torch.mode(labels[winning_indices], dim=1).values - return winning_labels diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py deleted file mode 100644 index 3a4f28f..0000000 --- a/prototorch/functions/distances.py +++ /dev/null @@ -1,259 +0,0 @@ -"""ProtoTorch distance functions.""" - -import numpy as np -import torch - -from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, - equal_int_shape, get_flat) - - -def squared_euclidean_distance(x, y): - r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. 
- - Compute :math:`{\langle \bm x - \bm y \rangle}_2` - - **Alias:** - ``prototorch.functions.distances.sed`` - """ - x, y = get_flat(x, y) - expanded_x = x.unsqueeze(dim=1) - batchwise_difference = y - expanded_x - differences_raised = torch.pow(batchwise_difference, 2) - distances = torch.sum(differences_raised, axis=2) - return distances - - -def euclidean_distance(x, y): - r"""Compute the Euclidean distance between :math:`x` and :math:`y`. - - Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}` - - :returns: Distance Tensor of shape :math:`X \times Y` - :rtype: `torch.tensor` - """ - x, y = get_flat(x, y) - distances_raised = squared_euclidean_distance(x, y) - distances = torch.sqrt(distances_raised) - return distances - - -def euclidean_distance_v2(x, y): - x, y = get_flat(x, y) - diff = y - x.unsqueeze(1) - pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() - # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the - # batch diagonal. See: - # https://pytorch.org/docs/stable/generated/torch.diagonal.html - distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1) - # print(f"{diff.shape=}") # (nx, ny, ndim) - # print(f"{pairwise_distances.shape=}") # (nx, ny, ny) - # print(f"{distances.shape=}") # (nx, ny) - return distances - - -def lpnorm_distance(x, y, p): - r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`. - Also known as Minkowski distance. - - Compute :math:`{\| \bm x - \bm y \|}_p`. - - Calls ``torch.cdist`` - - :param p: p parameter of the lp norm - """ - x, y = get_flat(x, y) - distances = torch.cdist(x, y, p=p) - return distances - - -def omega_distance(x, y, omega): - r"""Omega distance. - - Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p` - - :param `torch.tensor` omega: Two dimensional matrix - """ - x, y = get_flat(x, y) - projected_x = x @ omega - projected_y = y @ omega - distances = squared_euclidean_distance(projected_x, projected_y) - return distances - - -def lomega_distance(x, y, omegas): - r"""Localized Omega distance. - - Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p` - - :param `torch.tensor` omegas: Three dimensional matrix - """ - x, y = get_flat(x, y) - projected_x = x @ omegas - projected_y = torch.diagonal(y @ omegas).T - expanded_y = torch.unsqueeze(projected_y, dim=1) - batchwise_difference = expanded_y - projected_x - differences_squared = batchwise_difference**2 - distances = torch.sum(differences_squared, dim=2) - distances = distances.permute(1, 0) - return distances - - -def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10): - r"""Computes an euclidean distances matrix given two distinct vectors. - last dimension must be the vector dimension! - compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction! - - - ``x.shape = (number_of_x_vectors, vector_dim)`` - - ``y.shape = (number_of_y_vectors, vector_dim)`` - - output: matrix of distances (number_of_x_vectors, number_of_y_vectors) - """ - for tensor in [x, y]: - if tensor.ndim != 2: - raise ValueError( - "The tensor dimension must be two. You provide: tensor.ndim=" + - str(tensor.ndim) + ".") - if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): - raise ValueError( - "The vector shape must be equivalent in both tensors. 
You provide: tuple(y.shape)[1]=" - + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" + - str(tuple(y.shape)[1]) + ".") - - y = torch.transpose(y) - - diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) + - torch.sum(y**2, axis=0, keepdims=True)) - - if not squared: - if epsilon == 0: - diss = torch.sqrt(diss) - else: - diss = torch.sqrt(torch.max(diss, epsilon)) - - return diss - - -def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): - r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews - - For more info about Tangen distances see - - DOI:10.1109/IJCNN.2016.7727534. - - The subspaces is always assumed as transposed and must be orthogonal! - For local non sparse signals subspaces must be provided! - - - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN - - shape(protos): proto_number x dim1 x dim2 x ... x dimN - - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape) - - subspace should be orthogonalized - Pytorch implementation of Sascha Saralajew's tensorflow code. - Translation by Christoph Raab - """ - signal_shape, signal_int_shape = _int_and_mixed_shape(signals) - proto_shape, proto_int_shape = _int_and_mixed_shape(protos) - subspace_int_shape = tuple(subspaces.shape) - - # check if the shapes are correct - _check_shapes(signal_int_shape, proto_int_shape) - - atom_axes = list(range(3, len(signal_int_shape))) - # for sparse signals, we use the memory efficient implementation - if signal_int_shape[1] == 1: - signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])]) - - if len(atom_axes) > 1: - protos = torch.reshape(protos, [proto_shape[0], -1]) - - if subspaces.ndim == 2: - # clean solution without map if the matrix_scope is global - projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( - subspaces, torch.transpose(subspaces)) - - projected_signals = torch.dot(signals, projectors) - projected_protos = torch.dot(protos, projectors) - - diss = euclidean_distance_matrix(projected_signals, - projected_protos, - squared=squared, - epsilon=epsilon) - - diss = torch.reshape( - diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - - return torch.permute(diss, [0, 2, 1]) - - else: - - # no solution without map possible --> memory efficient but slow! 
- projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( - subspaces, - subspaces) # K.batch_dot(subspaces, subspaces, [2, 2]) - - projected_protos = (protos @ subspaces - ).T # K.batch_dot(projectors, protos, [1, 1])) - - def projected_norm(projector): - return torch.sum(torch.dot(signals, projector)**2, axis=1) - - diss = (torch.transpose(map(projected_norm, projectors)) - - 2 * torch.dot(signals, projected_protos) + - torch.sum(projected_protos**2, axis=0, keepdims=True)) - - if not squared: - if epsilon == 0: - diss = torch.sqrt(diss) - else: - diss = torch.sqrt(torch.max(diss, epsilon)) - - diss = torch.reshape( - diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - - return torch.permute(diss, [0, 2, 1]) - - else: - signals = signals.permute([0, 2, 1] + atom_axes) - - diff = signals - protos - - # global tangent space - if subspaces.ndim == 2: - # Scope Projectors - projectors = subspaces # - - # Scope: Tangentspace Projections - diff = torch.reshape( - diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) - projected_diff = diff @ projectors - projected_diff = torch.reshape( - projected_diff, - (signal_shape[0], signal_shape[2], signal_shape[1]) + - signal_shape[3:], - ) - - diss = torch.norm(projected_diff, 2, dim=-1) - return diss.permute([0, 2, 1]) - - # local tangent spaces - else: - # Scope: Calculate Projectors - projectors = subspaces - - # Scope: Tangentspace Projections - diff = torch.reshape( - diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) - diff = diff.permute([1, 0, 2]) - projected_diff = torch.bmm(diff, projectors) - projected_diff = torch.reshape( - projected_diff, - (signal_shape[1], signal_shape[0], signal_shape[2]) + - signal_shape[3:], - ) - - diss = torch.norm(projected_diff, 2, dim=-1) - return diss.permute([1, 0, 2]).squeeze(-1) - - -# Aliases -sed = squared_euclidean_distance diff --git a/prototorch/functions/helper.py b/prototorch/functions/helper.py deleted file mode 100644 index ec95ba0..0000000 --- a/prototorch/functions/helper.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch - - -def get_flat(*args): - rv = [x.view(x.size(0), -1) for x in args] - return rv - - -def calculate_prototype_accuracy(y_pred, y_true, plabels): - """Computes the accuracy of a prototype based model. - via Winner-Takes-All rule. - Requirement: - y_pred.shape == y_true.shape - unique(y_pred) in plabels - """ - with torch.no_grad(): - idx = torch.argmin(y_pred, axis=1) - return torch.true_divide(torch.sum(y_true == plabels[idx]), - len(y_pred)) * 100 - - -def predict_label(y_pred, plabels): - r""" Predicts labels given a prediction of a prototype based model. 
- """ - with torch.no_grad(): - return plabels[torch.argmin(y_pred, 1)] - - -def mixed_shape(inputs): - if not torch.is_tensor(inputs): - raise ValueError("Input must be a tensor.") - else: - int_shape = list(inputs.shape) - # sometimes int_shape returns mixed integer types - int_shape = [int(i) if i is not None else i for i in int_shape] - tensor_shape = inputs.shape - - for i, s in enumerate(int_shape): - if s is None: - int_shape[i] = tensor_shape[i] - return tuple(int_shape) - - -def equal_int_shape(shape_1, shape_2): - if not isinstance(shape_1, - (tuple, list)) or not isinstance(shape_2, (tuple, list)): - raise ValueError("Input shapes must list or tuple.") - for shape in [shape_1, shape_2]: - if not all([isinstance(x, int) or x is None for x in shape]): - raise ValueError( - "Input shapes must be list or tuple of int and None values.") - - if len(shape_1) != len(shape_2): - return False - else: - for axis, value in enumerate(shape_1): - if value is not None and shape_2[axis] not in {value, None}: - return False - return True - - -def _check_shapes(signal_int_shape, proto_int_shape): - if len(signal_int_shape) < 4: - raise ValueError( - "The number of signal dimensions must be >=4. You provide: " + - str(len(signal_int_shape))) - - if len(proto_int_shape) < 2: - raise ValueError( - "The number of proto dimensions must be >=2. You provide: " + - str(len(proto_int_shape))) - - if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]): - raise ValueError( - "The atom shape of signals must be equal protos. You provide: signals.shape[3:]=" - + str(signal_int_shape[3:]) + " != protos.shape[1:]=" + - str(proto_int_shape[1:])) - - # not a sparse signal - if signal_int_shape[1] != 1: - if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]): - raise ValueError( - "If the signal is not sparse, the number of prototypes must be equal in signals and " - "protos. 
You provide: " + str(signal_int_shape[1]) + " != " + - str(proto_int_shape[0])) - - return True - - -def _int_and_mixed_shape(tensor): - shape = mixed_shape(tensor) - int_shape = tuple(i if isinstance(i, int) else None for i in shape) - - return shape, int_shape diff --git a/prototorch/functions/initializers.py b/prototorch/functions/initializers.py deleted file mode 100644 index 345b723..0000000 --- a/prototorch/functions/initializers.py +++ /dev/null @@ -1,107 +0,0 @@ -"""ProtoTorch initialization functions.""" - -from itertools import chain - -import torch - -INITIALIZERS = dict() - - -def register_initializer(function): - """Add the initializer to the registry.""" - INITIALIZERS[function.__name__] = function - return function - - -def labels_from(distribution, one_hot=True): - """Takes a distribution tensor and returns a labels tensor.""" - num_classes = distribution.shape[0] - llist = [[i] * n for i, n in zip(range(num_classes), distribution)] - # labels = [l for cl in llist for l in cl] # flatten the list of lists - flat_llist = list(chain(*llist)) # flatten label list with itertools.chain - plabels = torch.tensor(flat_llist, requires_grad=False) - if one_hot: - return torch.eye(num_classes)[plabels] - return plabels - - -@register_initializer -def ones(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.ones(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def zeros(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.zeros(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def rand(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.rand(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def randn(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.randn(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - pdim = x_train.shape[1] - protos = torch.empty(num_protos, pdim) - plabels = labels_from(prototype_distribution, one_hot) - for i, label in enumerate(plabels): - matcher = torch.eq(label.unsqueeze(dim=0), y_train) - if one_hot: - num_classes = y_train.size()[1] - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - xl = x_train[matcher] - mean_xl = torch.mean(xl, dim=0) - protos[i] = mean_xl - plabels = labels_from(prototype_distribution, one_hot=one_hot) - return protos, plabels - - -@register_initializer -def stratified_random(x_train, - y_train, - prototype_distribution, - one_hot=True, - epsilon=1e-7): - num_protos = torch.sum(prototype_distribution) - pdim = x_train.shape[1] - protos = torch.empty(num_protos, pdim) - plabels = labels_from(prototype_distribution, one_hot) - for i, label in enumerate(plabels): - matcher = torch.eq(label.unsqueeze(dim=0), y_train) - if one_hot: - num_classes = y_train.size()[1] - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - xl = x_train[matcher] - rand_index 
= torch.zeros(1).long().random_(0, xl.shape[0] - 1) - random_xl = xl[rand_index] - protos[i] = random_xl + epsilon - plabels = labels_from(prototype_distribution, one_hot=one_hot) - return protos, plabels - - -def get_initializer(funcname): - """Deserialize the initializer.""" - if callable(funcname): - return funcname - if funcname in INITIALIZERS: - return INITIALIZERS.get(funcname) - raise NameError(f"Initializer {funcname} was not found.") diff --git a/prototorch/functions/normalization.py b/prototorch/functions/normalization.py deleted file mode 100644 index 96980b8..0000000 --- a/prototorch/functions/normalization.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, division, print_function - -import torch - - -def orthogonalization(tensors): - r""" Orthogonalization of a given tensor via polar decomposition. - """ - u, _, v = torch.svd(tensors, compute_uv=True) - u_shape = tuple(list(u.shape)) - v_shape = tuple(list(v.shape)) - - # reshape to (num x N x M) - u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1])) - v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1])) - - out = u @ v.permute([0, 2, 1]) - - out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], )) - - return out - - -def trace_normalization(tensors): - r""" Trace normalization - """ - epsilon = torch.tensor([1e-10], dtype=torch.float64) - # Scope trace_normalization - constant = torch.trace(tensors) - - if epsilon != 0: - constant = torch.max(constant, epsilon) - - return tensors / constant diff --git a/prototorch/functions/transforms.py b/prototorch/functions/transforms.py deleted file mode 100644 index 334d382..0000000 --- a/prototorch/functions/transforms.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch - - -# Functions -def gaussian(distances, variance): - return torch.exp(-(distances * distances) / (2 * variance)) - - -def rank_scaled_gaussian(distances, lambd): - order = torch.argsort(distances, dim=1) - ranks = torch.argsort(order, dim=1) - - return torch.exp(-torch.exp(-ranks / lambd) * distances) - - -# Modules -class GaussianPrior(torch.nn.Module): - def __init__(self, variance): - super().__init__() - self.variance = variance - - def forward(self, distances): - return gaussian(distances, self.variance) - - -class RankScaledGaussianPrior(torch.nn.Module): - def __init__(self, lambd): - super().__init__() - self.lambd = lambd - - def forward(self, distances): - return rank_scaled_gaussian(distances, self.lambd) diff --git a/prototorch/modules/__init__.py b/prototorch/modules/__init__.py deleted file mode 100644 index fc7ab87..0000000 --- a/prototorch/modules/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""ProtoTorch modules.""" - -from .competitions import * -from .pooling import * -from .wrappers import LambdaLayer, LossLayer diff --git a/prototorch/modules/competitions.py b/prototorch/modules/competitions.py deleted file mode 100644 index 53db40b..0000000 --- a/prototorch/modules/competitions.py +++ /dev/null @@ -1,42 +0,0 @@ -"""ProtoTorch Competition Modules.""" - -import torch - -from prototorch.functions.competitions import knnc, wtac - - -class WTAC(torch.nn.Module): - """Winner-Takes-All-Competition Layer. - - Thin wrapper over the `wtac` function. - - """ - def forward(self, distances, labels): - return wtac(distances, labels) - - -class LTAC(torch.nn.Module): - """Loser-Takes-All-Competition Layer. - - Thin wrapper over the `wtac` function. 
- - """ - def forward(self, probs, labels): - return wtac(-1.0 * probs, labels) - - -class KNNC(torch.nn.Module): - """K-Nearest-Neighbors-Competition. - - Thin wrapper over the `knnc` function. - - """ - def __init__(self, k=1, **kwargs): - super().__init__(**kwargs) - self.k = k - - def forward(self, distances, labels): - return knnc(distances, labels, k=self.k) - - def extra_repr(self): - return f"k: {self.k}" diff --git a/prototorch/modules/losses.py b/prototorch/modules/losses.py deleted file mode 100644 index d80ce15..0000000 --- a/prototorch/modules/losses.py +++ /dev/null @@ -1,59 +0,0 @@ -"""ProtoTorch losses.""" - -import torch - -from prototorch.functions.activations import get_activation -from prototorch.functions.losses import glvq_loss - - -class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): - super().__init__(**kwargs) - self.margin = margin - self.squashing = get_activation(squashing) - self.beta = torch.tensor(beta) - - def forward(self, outputs, targets): - distances, plabels = outputs - mu = glvq_loss(distances, targets, prototype_labels=plabels) - batch_loss = self.squashing(mu + self.margin, beta=self.beta) - return torch.sum(batch_loss, dim=0) - - -class NeuralGasEnergy(torch.nn.Module): - def __init__(self, lm, **kwargs): - super().__init__(**kwargs) - self.lm = lm - - def forward(self, d): - order = torch.argsort(d, dim=1) - ranks = torch.argsort(order, dim=1) - cost = torch.sum(self._nghood_fn(ranks, self.lm) * d) - - return cost, order - - def extra_repr(self): - return f"lambda: {self.lm}" - - @staticmethod - def _nghood_fn(rankings, lm): - return torch.exp(-rankings / lm) - - -class GrowingNeuralGasEnergy(NeuralGasEnergy): - def __init__(self, topology_layer, **kwargs): - super().__init__(**kwargs) - self.topology_layer = topology_layer - - @staticmethod - def _nghood_fn(rankings, topology): - winner = rankings[:, 0] - - weights = torch.zeros_like(rankings, dtype=torch.float) - weights[torch.arange(rankings.shape[0]), winner] = 1.0 - - neighbours = topology.get_neighbours(winner) - - weights[neighbours] = 0.1 - - return weights diff --git a/prototorch/modules/models.py b/prototorch/modules/models.py deleted file mode 100644 index 3c9b741..0000000 --- a/prototorch/modules/models.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -from torch import nn - -from prototorch.components import LabeledComponents, StratifiedMeanInitializer -from prototorch.functions.distances import euclidean_distance_matrix -from prototorch.functions.normalization import orthogonalization - - -class GTLVQ(nn.Module): - r""" Generalized Tangent Learning Vector Quantization - - Parameters - ---------- - num_classes: int - Number of classes of the given classification problem. - - subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim) - Subspace data for the point approximation, required - - prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional) - prototype data for initalization of the prototypes used in GTLVQ. - - subspace_size: int (default=256,optional) - Subspace dimension of the Projectors. Currently only supported - with tagnent_projection_type=global. - - tangent_projection_type: string - Specifies the tangent projection type - options: local - local_proj - global - local: computes the tangent distances without emphasizing projected - data. Only distances are available - local_proj: computs tangent distances and returns the projected data - for further use. 
Be careful: data is repeated by number of prototypes - global: Number of subspaces is set to one and every prototypes - uses the same. - - prototypes_per_class: int (default=2,optional) - Number of prototypes per class - - feature_dim: int (default=256) - Dimensionality of the feature space specified as integer. - Prototype dimension. - - Notes - ----- - The GTLVQ [1] is a prototype-based classification learning model. The - GTLVQ uses the Tangent-Distances for a local point approximation - of an assumed data manifold via prototypial representations. - - The GTLVQ requires subspace projectors for transforming the data - and prototypes into the affine subspace. Every prototype is - equipped with a specific subpspace and represents a point - approximation of the assumed manifold. - - In practice prototypes and data are projected on this manifold - and pairwise euclidean distance computes. - - References - ---------- - .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning - in classification based on manifolc. models and its relation - to tangent metric learning. In: 2017 International Joint - Conference on Neural Networks (IJCNN). - Bd. 2017-May : IEEE, 2017, S. 1756–1765 - """ - def __init__( - self, - num_classes, - subspace_data=None, - prototype_data=None, - subspace_size=256, - tangent_projection_type="local", - prototypes_per_class=2, - feature_dim=256, - ): - super(GTLVQ, self).__init__() - - self.num_protos = num_classes * prototypes_per_class - self.num_protos_class = prototypes_per_class - self.subspace_size = feature_dim if subspace_size is None else subspace_size - self.feature_dim = feature_dim - self.num_classes = num_classes - - cls_initializer = StratifiedMeanInitializer(prototype_data) - cls_distribution = { - "num_classes": num_classes, - "prototypes_per_class": prototypes_per_class, - } - - self.cls = LabeledComponents(cls_distribution, cls_initializer) - - if subspace_data is None: - raise ValueError("Init Data must be specified!") - - self.tpt = tangent_projection_type - with torch.no_grad(): - if self.tpt == "local": - self.init_local_subspace(subspace_data, subspace_size, - self.num_protos) - elif self.tpt == "global": - self.init_gobal_subspace(subspace_data, subspace_size) - else: - self.subspaces = None - - def forward(self, x): - if self.tpt == "local": - dis = self.local_tangent_distances(x) - elif self.tpt == "gloabl": - dis = self.global_tangent_distances(x) - else: - dis = (x @ self.cls.prototypes.T) / ( - torch.norm(x, dim=1, keepdim=True) @ torch.norm( - self.cls.prototypes, dim=1, keepdim=True).T) - return dis - - def init_gobal_subspace(self, data, num_subspaces): - _, _, v = torch.svd(data) - subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T - subspaces = subspace[:, :num_subspaces] - self.subspaces = nn.Parameter(subspaces, requires_grad=True) - - def init_local_subspace(self, data, num_subspaces, num_protos): - data = data - torch.mean(data, dim=0) - _, _, v = torch.svd(data, some=False) - v = v[:, :num_subspaces] - subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0) - self.subspaces = nn.Parameter(subspaces, requires_grad=True) - - def global_tangent_distances(self, x): - # Tangent Projection - x, projected_prototypes = ( - x @ self.subspaces, - self.cls.prototypes @ self.subspaces, - ) - # Euclidean Distance - return euclidean_distance_matrix(x, projected_prototypes) - - def local_tangent_distances(self, x): - - # Tangent Distance - x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components, - x.size(-1)) - protos = 
self.cls()[0].unsqueeze(0).expand(x.size(0), - self.cls.num_components, - x.size(-1)) - projectors = torch.eye( - self.subspaces.shape[-2], device=x.device) - torch.bmm( - self.subspaces, self.subspaces.permute([0, 2, 1])) - diff = (x - protos) - diff = diff.permute([1, 0, 2]) - diff = torch.bmm(diff, projectors) - diff = torch.norm(diff, 2, dim=-1).T - return diff - - def get_parameters(self): - return { - "params": self.cls.components, - }, { - "params": self.subspaces - } - - def orthogonalize_subspace(self): - if self.subspaces is not None: - with torch.no_grad(): - ortho_subpsaces = (orthogonalization(self.subspaces) - if self.tpt == "global" else - torch.nn.init.orthogonal_(self.subspaces)) - self.subspaces.copy_(ortho_subpsaces) diff --git a/prototorch/modules/pooling.py b/prototorch/modules/pooling.py deleted file mode 100644 index 7e57ffe..0000000 --- a/prototorch/modules/pooling.py +++ /dev/null @@ -1,32 +0,0 @@ -"""ProtoTorch Pooling Modules.""" - -import torch - -from prototorch.functions.pooling import (stratified_max_pooling, - stratified_min_pooling, - stratified_prod_pooling, - stratified_sum_pooling) - - -class StratifiedSumPooling(torch.nn.Module): - """Thin wrapper over the `stratified_sum_pooling` function.""" - def forward(self, values, labels): - return stratified_sum_pooling(values, labels) - - -class StratifiedProdPooling(torch.nn.Module): - """Thin wrapper over the `stratified_prod_pooling` function.""" - def forward(self, values, labels): - return stratified_prod_pooling(values, labels) - - -class StratifiedMinPooling(torch.nn.Module): - """Thin wrapper over the `stratified_min_pooling` function.""" - def forward(self, values, labels): - return stratified_min_pooling(values, labels) - - -class StratifiedMaxPooling(torch.nn.Module): - """Thin wrapper over the `stratified_max_pooling` function.""" - def forward(self, values, labels): - return stratified_max_pooling(values, labels) diff --git a/prototorch/nn/__init__.py b/prototorch/nn/__init__.py new file mode 100644 index 0000000..bf2445e --- /dev/null +++ b/prototorch/nn/__init__.py @@ -0,0 +1,4 @@ +"""ProtoTorch Neural Network Module""" + +from .activations import * +from .wrappers import * diff --git a/prototorch/functions/activations.py b/prototorch/nn/activations.py similarity index 79% rename from prototorch/functions/activations.py rename to prototorch/nn/activations.py index c5673ae..ab70762 100644 --- a/prototorch/functions/activations.py +++ b/prototorch/nn/activations.py @@ -1,4 +1,4 @@ -"""ProtoTorch activation functions.""" +"""ProtoTorch activations""" import torch @@ -57,6 +57,10 @@ def get_activation(funcname): """Deserialize the activation function.""" if callable(funcname): return funcname - if funcname in ACTIVATIONS: + elif funcname in ACTIVATIONS: return ACTIVATIONS.get(funcname) - raise NameError(f"Activation {funcname} was not found.") + else: + emsg = f"Unable to find matching function for `{funcname}` " \ + f"in `prototorch.nn.activations`. " + helpmsg = f"Possible values are {list(ACTIVATIONS.keys())}." 
+ raise NameError(emsg + helpmsg) diff --git a/prototorch/modules/wrappers.py b/prototorch/nn/wrappers.py similarity index 97% rename from prototorch/modules/wrappers.py rename to prototorch/nn/wrappers.py index da94c52..c3fe781 100644 --- a/prototorch/modules/wrappers.py +++ b/prototorch/nn/wrappers.py @@ -1,4 +1,4 @@ -"""ProtoTorch Wrappers.""" +"""ProtoTorch wrappers.""" import torch diff --git a/prototorch/utils/__init__.py b/prototorch/utils/__init__.py index e69de29..26ccedd 100644 --- a/prototorch/utils/__init__.py +++ b/prototorch/utils/__init__.py @@ -0,0 +1,8 @@ +"""ProtoFlow utils module""" + +from .colors import hex_to_rgb, rgb_to_hex +from .utils import ( + mesh2d, + parse_data_arg, + parse_distribution, +) diff --git a/prototorch/utils/celluloid.py b/prototorch/utils/celluloid.py deleted file mode 100644 index 56eec36..0000000 --- a/prototorch/utils/celluloid.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid.""" - -from collections import defaultdict -from typing import Dict, List - -from matplotlib.animation import ArtistAnimation -from matplotlib.artist import Artist -from matplotlib.figure import Figure - -__version__ = "0.2.0" - - -class Camera: - """Make animations easier.""" - def __init__(self, figure: Figure) -> None: - """Create camera from matplotlib figure.""" - self._figure = figure - # need to keep track off artists for each axis - self._offsets: Dict[str, Dict[int, int]] = { - k: defaultdict(int) - for k in - ["collections", "patches", "lines", "texts", "artists", "images"] - } - self._photos: List[List[Artist]] = [] - - def snap(self) -> List[Artist]: - """Capture current state of the figure.""" - frame_artists: List[Artist] = [] - for i, axis in enumerate(self._figure.axes): - if axis.legend_ is not None: - axis.add_artist(axis.legend_) - for name in self._offsets: - new_artists = getattr(axis, name)[self._offsets[name][i]:] - frame_artists += new_artists - self._offsets[name][i] += len(new_artists) - self._photos.append(frame_artists) - return frame_artists - - def animate(self, *args, **kwargs) -> ArtistAnimation: - """Animate the snapshots taken. - Uses matplotlib.animation.ArtistAnimation - Returns - ------- - ArtistAnimation - """ - return ArtistAnimation(self._figure, self._photos, *args, **kwargs) diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 65543e4..61ad1a0 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,78 +1,15 @@ -"""ProtoFlow color utilities.""" - -import matplotlib.lines as mlines -from matplotlib import cm -from matplotlib.colors import Normalize, to_hex, to_rgb +"""ProtoFlow color utilities""" -def color_scheme(n, - cmap="viridis", - form="hex", - tikz=False, - zero_indexed=False): - """Return *n* colors from the color scheme. - - Arguments: - n (int): number of colors to return - - Keyword Arguments: - cmap (str): Name of a matplotlib `colormap\ - `_. - form (str): Colorformat (supports "hex" and "rgb"). - tikz (bool): Output as `TikZ `_ - command. - zero_indexed (bool): Use zero indexing for output array. 
-
-    Returns:
-        (list): List of colors
-    """
-    cmap = cm.get_cmap(cmap)
-    colornorm = Normalize(vmin=1, vmax=n)
-    hex_map = dict()
-    rgb_map = dict()
-    for cl in range(1, n + 1):
-        if zero_indexed:
-            hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
-        else:
-            hex_map[cl] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
-    if tikz:
-        for k, v in rgb_map.items():
-            print(f"\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
-    if form == "hex":
-        return hex_map
-    elif form == "rgb":
-        return rgb_map
-    else:
-        return hex_map
+def hex_to_rgb(hex_values):
+    for v in hex_values:
+        v = v.lstrip('#')
+        lv = len(v)
+        c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
+        yield c
-
-def get_legend_handles(labels, marker="dots", zero_indexed=False):
-    """Return matplotlib legend handles and colors."""
-    handles = list()
-    n = len(labels)
-    colors = color_scheme(n,
-                          cmap="viridis",
-                          form="hex",
-                          zero_indexed=zero_indexed)
-    for label, color in zip(labels, colors.values()):
-        if marker == "dots":
-            handle = mlines.Line2D(
-                [],
-                [],
-                color="white",
-                markerfacecolor=color,
-                marker="o",
-                markersize=10,
-                markeredgecolor="k",
-                label=label,
-            )
-        else:
-            handle = mlines.Line2D([], [],
-                                   color=color,
-                                   marker="",
-                                   markersize=15,
-                                   label=label)
-        handles.append(handle)
-    return handles, colors
+def rgb_to_hex(rgb_values):
+    for v in rgb_values:
+        c = "%02x%02x%02x" % tuple(v)
+        yield c
diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py
new file mode 100644
index 0000000..df718bb
--- /dev/null
+++ b/prototorch/utils/utils.py
@@ -0,0 +1,105 @@
+"""ProtoTorch utilities"""
+
+import warnings
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+
+def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
+    if x is not None:
+        x_shift = border * np.ptp(x[:, 0])
+        y_shift = border * np.ptp(x[:, 1])
+        x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
+        y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
+    else:
+        x_min, x_max = -border, border
+        y_min, y_max = -border, border
+    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
+                         np.linspace(y_min, y_max, resolution))
+    mesh = np.c_[xx.ravel(), yy.ravel()]
+    return mesh, xx, yy
+
+
+def distribution_from_list(list_dist: list[int],
+                           clabels: Optional[Iterable[int]] = None):
+    clabels = clabels or list(range(len(list_dist)))
+    distribution = dict(zip(clabels, list_dist))
+    return distribution
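A quick editorial sketch of the helpers above (not part of the patch; note that `distribution_from_list` is imported from its module path here, since `prototorch/utils/__init__.py` only re-exports `mesh2d`, `parse_data_arg`, and `parse_distribution`):

```python
import prototorch as pt
from prototorch.utils.utils import distribution_from_list

# The color helpers are generators over collections of values.
print(list(pt.utils.hex_to_rgb(["#ff0000"])))    # [[255, 0, 0]]
print(list(pt.utils.rgb_to_hex([(255, 0, 0)])))  # ['ff0000']

# Map per-class counts onto class labels (defaults to labels 0..n-1).
print(distribution_from_list([1, 1, 0, 2]))  # {0: 1, 1: 1, 2: 0, 3: 2}
```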
+def parse_distribution(
+        user_distribution,
+        clabels: Optional[Iterable[int]] = None) -> dict[int, int]:
+    """Parse user-provided distribution.
+
+    Return a dictionary with integer keys that represent the class labels and
+    values that denote the number of components/prototypes with that class
+    label.
+
+    The argument `user_distribution` could be any one of a number of allowed
+    formats. If it is a Python list, it is assumed that there are as many
+    entries in this list as there are classes, and the value at each index of
+    this list describes the number of prototypes for that particular class. So,
+    [1, 1, 1] implies that we have three classes with one prototype per class.
+    If it is a Python tuple, a shorthand of (num_classes, prototypes_per_class)
+    is assumed. If it is a Python dictionary, the key-value pairs describe the
+    class label and the number of prototypes for that class respectively. So,
+    {0: 2, 1: 2, 2: 2} implies that we have three classes with labels {0, 1,
+    2}, each equipped with two prototypes. If however, the dictionary contains
+    the keys "num_classes" and "per_class", they are parsed to use their values
+    as one might expect.
+
+    """
+    if isinstance(user_distribution, dict):
+        if "num_classes" in user_distribution.keys():
+            num_classes = int(user_distribution["num_classes"])
+            per_class = int(user_distribution["per_class"])
+            return distribution_from_list([per_class] * num_classes, clabels)
+        else:
+            return user_distribution
+    elif isinstance(user_distribution, tuple):
+        assert len(user_distribution) == 2
+        num_classes, per_class = user_distribution
+        num_classes, per_class = int(num_classes), int(per_class)
+        return distribution_from_list([per_class] * num_classes, clabels)
+    elif isinstance(user_distribution, list):
+        return distribution_from_list(user_distribution, clabels)
+    else:
+        msg = "`distribution` was not understood. " \
+            f"You have provided: {user_distribution}."
+        raise ValueError(msg)
+
+
+def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
+    """Return data and target as torch tensors."""
+    if isinstance(data_arg, Dataset):
+        if hasattr(data_arg, "__len__"):
+            ds_size = len(data_arg)  # type: ignore
+            loader = DataLoader(data_arg, batch_size=ds_size)
+            data, targets = next(iter(loader))
+        else:
+            emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
+            raise TypeError(emsg)
+
+    elif isinstance(data_arg, DataLoader):
+        data = torch.tensor([])
+        targets = torch.tensor([])
+        for x, y in data_arg:
+            data = torch.cat([data, x])
+            targets = torch.cat([targets, y])
+    else:
+        assert len(data_arg) == 2
+        data, targets = data_arg
+        if not isinstance(data, torch.Tensor):
+            wmsg = f"Converting data to {torch.Tensor}..."
+            warnings.warn(wmsg)
+            data = torch.Tensor(data)
+        if not isinstance(targets, torch.LongTensor):
+            wmsg = f"Converting targets to {torch.LongTensor}..."
+ warnings.warn(wmsg) + targets = torch.LongTensor(targets) + return data, targets diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..33c1a02 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[pylint] +disable = + too-many-arguments, + too-few-public-methods, + fixme, + +[pycodestyle] +max-line-length = 79 + +[isort] +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 3 +use_parentheses = True +line_length = 79 \ No newline at end of file diff --git a/setup.py b/setup.py index 44331f5..3cb2122 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,9 @@ setup( license="MIT", install_requires=INSTALL_REQUIRES, extras_require={ - "docs": DOCS, "datasets": DATASETS, + "dev": DEV, + "docs": DOCS, "examples": EXAMPLES, "tests": TESTS, "all": ALL, diff --git a/tests/test_components.py b/tests/test_components.py deleted file mode 100644 index 5d95eb2..0000000 --- a/tests/test_components.py +++ /dev/null @@ -1,26 +0,0 @@ -"""ProtoTorch components test suite.""" - -import torch - -import prototorch as pt - - -def test_labcomps_zeros_init(): - protos = torch.zeros(3, 2) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=pt.components.Zeros(2), - ) - assert (c.components == protos).any() == True - - -def test_labcomps_warmstart(): - protos = torch.randn(3, 2) - plabels = torch.tensor([1, 2, 3]) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=None, - initialized_components=[protos, plabels], - ) - assert (c.components == protos).any() == True - assert (c.component_labels == plabels).any() == True diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..d007f9b --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,760 @@ +"""ProtoTorch core test suite""" + +import unittest + +import numpy as np +import pytest +import torch + +import prototorch as pt +from prototorch.utils import parse_distribution + + +# Utils +def test_parse_distribution_dict_0(): + distribution = {"num_classes": 1, "per_class": 0} + distribution = parse_distribution(distribution) + assert distribution == {0: 0} + + +def test_parse_distribution_dict_1(): + distribution = dict(num_classes=3, per_class=2) + distribution = parse_distribution(distribution) + assert distribution == {0: 2, 1: 2, 2: 2} + + +def test_parse_distribution_dict_2(): + distribution = {0: 1, 2: 2, -1: 3} + distribution = parse_distribution(distribution) + assert distribution == {0: 1, 2: 2, -1: 3} + + +def test_parse_distribution_tuple(): + distribution = (2, 3) + distribution = parse_distribution(distribution) + assert distribution == {0: 3, 1: 3} + + +def test_parse_distribution_list(): + distribution = [1, 1, 0, 2] + distribution = parse_distribution(distribution) + assert distribution == {0: 1, 1: 1, 2: 0, 3: 2} + + +def test_parse_distribution_custom_labels(): + distribution = [1, 1, 0, 2] + clabels = [1, 2, 5, 3] + distribution = parse_distribution(distribution, clabels) + assert distribution == {1: 1, 2: 1, 5: 0, 3: 2} + + +# Components initializers +def test_literal_comp_generate(): + protos = torch.rand(4, 3, 5, 5) + c = pt.initializers.LiteralCompInitializer(protos) + components = c.generate([]) + assert torch.allclose(components, protos) + + +def test_literal_comp_generate_from_list(): + protos = [[0, 1], [2, 3], [4, 5]] + c = pt.initializers.LiteralCompInitializer(protos) + with pytest.warns(UserWarning): + components = c.generate([]) + assert torch.allclose(components, torch.Tensor(protos)) + + +def test_shape_aware_raises_error(): + 
with pytest.raises(TypeError): + _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) + + +def test_data_aware_comp_generate(): + protos = torch.rand(4, 3, 5, 5) + c = pt.initializers.DataAwareCompInitializer(protos) + components = c.generate(num_components="IgnoreMe!") + assert torch.allclose(components, protos) + + +def test_class_aware_comp_generate(): + protos = torch.rand(4, 2, 3, 5, 5) + plabels = torch.tensor([0, 0, 1, 1]).long() + c = pt.initializers.ClassAwareCompInitializer([protos, plabels]) + components = c.generate(distribution=[]) + assert torch.allclose(components, protos) + + +def test_zeros_comp_generate(): + shape = (3, 5, 5) + c = pt.initializers.ZerosCompInitializer(shape) + components = c.generate(num_components=4) + assert torch.allclose(components, torch.zeros(4, 3, 5, 5)) + + +def test_ones_comp_generate(): + c = pt.initializers.OnesCompInitializer(2) + components = c.generate(num_components=3) + assert torch.allclose(components, torch.ones(3, 2)) + + +def test_fill_value_comp_generate(): + c = pt.initializers.FillValueCompInitializer(2, 0.0) + components = c.generate(num_components=3) + assert torch.allclose(components, torch.zeros(3, 2)) + + +def test_uniform_comp_generate_min_max_bound(): + c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0) + components = c.generate(num_components=1024) + assert components.min() >= -1.0 + assert components.max() <= 1.0 + + +def test_random_comp_generate_mean(): + c = pt.initializers.RandomNormalCompInitializer(2, -1.0) + components = c.generate(num_components=1024) + assert torch.allclose(components.mean(), + torch.tensor(-1.0), + rtol=1e-05, + atol=1e-01) + + +def test_comp_generate_0_components(): + c = pt.initializers.ZerosCompInitializer(2) + _ = c.generate(num_components=0) + + +def test_stratified_mean_comp_generate(): + # yapf: disable + x = torch.Tensor( + [[0, -1, -2], + [10, 11, 12], + [0, 0, 0], + [2, 2, 2]]) + y = torch.LongTensor([0, 0, 1, 1]) + desired = torch.Tensor( + [[5.0, 5.0, 5.0], + [1.0, 1.0, 1.0]]) + # yapf: enable + c = pt.initializers.StratifiedMeanCompInitializer(data=[x, y]) + actual = c.generate([1, 1]) + assert torch.allclose(actual, desired) + + +def test_stratified_selection_comp_generate(): + # yapf: disable + x = torch.Tensor( + [[0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [1, 1, 1]]) + y = torch.LongTensor([0, 1, 0, 1]) + desired = torch.Tensor( + [[0, 0, 0], + [1, 1, 1]]) + # yapf: enable + c = pt.initializers.StratifiedSelectionCompInitializer(data=[x, y]) + actual = c.generate([1, 1]) + assert torch.allclose(actual, desired) + + +# Labels initializers +def test_literal_labels_init(): + l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) + with pytest.warns(UserWarning): + labels = l.generate([]) + assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) + + +def test_labels_init_from_list(): + l = pt.initializers.LabelsInitializer() + components = l.generate(distribution=[1, 1, 1]) + assert torch.allclose(components, torch.LongTensor([0, 1, 2])) + + +def test_labels_init_from_tuple_legal(): + l = pt.initializers.LabelsInitializer() + components = l.generate(distribution=(3, 1)) + assert torch.allclose(components, torch.LongTensor([0, 1, 2])) + + +def test_labels_init_from_tuple_illegal(): + l = pt.initializers.LabelsInitializer() + with pytest.raises(AssertionError): + _ = l.generate(distribution=(1, 1, 1)) + + +def test_data_aware_labels_init(): + data, targets = [0, 1, 2, 3], [0, 0, 1, 1] + ds = pt.datasets.NumpyDataset(data, targets) + l = 
pt.initializers.DataAwareLabelsInitializer(ds)
+    labels = l.generate([])
+    assert torch.allclose(labels, torch.LongTensor(targets))
+
+
+# Reasonings initializers
+def test_literal_reasonings_init():
+    r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
+    with pytest.warns(UserWarning):
+        reasonings = r.generate([])
+    assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
+
+
+def test_random_reasonings_init():
+    r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
+    reasonings = r.generate(distribution=[0, 1])
+    assert torch.numel(reasonings) == 1 * 2 * 2
+    assert reasonings.min() >= 0.2
+    assert reasonings.max() <= 0.8
+
+
+def test_zeros_reasonings_init():
+    r = pt.initializers.ZerosReasoningsInitializer()
+    reasonings = r.generate(distribution=[0, 1])
+    assert torch.allclose(reasonings, torch.zeros(1, 2, 2))
+
+
+def test_ones_reasonings_init():
+    r = pt.initializers.OnesReasoningsInitializer()
+    reasonings = r.generate(distribution=[1, 2, 3])
+    assert torch.allclose(reasonings, torch.ones(6, 3, 2))
+
+
+def test_pure_positive_reasonings_init_one_per_class():
+    r = pt.initializers.PurePositiveReasoningsInitializer(
+        components_first=False)
+    reasonings = r.generate(distribution=(4, 1))
+    assert torch.allclose(reasonings[0], torch.eye(4))
+
+
+def test_pure_positive_reasonings_init_unrepresented_classes():
+    r = pt.initializers.PurePositiveReasoningsInitializer()
+    reasonings = r.generate(distribution=[9, 0, 0, 0])
+    assert reasonings.shape[0] == 9
+    assert reasonings.shape[1] == 4
+    assert reasonings.shape[2] == 2
+
+
+def test_random_reasonings_init_channels_not_first():
+    r = pt.initializers.RandomReasoningsInitializer(components_first=False)
+    reasonings = r.generate(distribution=[0, 0, 0, 1])
+    assert reasonings.shape[0] == 2
+    assert reasonings.shape[1] == 4
+    assert reasonings.shape[2] == 1
+
+
+# Transform initializers
+def test_eye_transform_init_square():
+    t = pt.initializers.EyeTransformInitializer()
+    I = t.generate(3, 3)
+    assert torch.allclose(I, torch.eye(3))
+
+
+def test_eye_transform_init_narrow():
+    t = pt.initializers.EyeTransformInitializer()
+    actual = t.generate(3, 2)
+    desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
+    assert torch.allclose(actual, desired)
+
+
+def test_eye_transform_init_wide():
+    t = pt.initializers.EyeTransformInitializer()
+    actual = t.generate(2, 3)
+    desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
+    assert torch.allclose(actual, desired)
+
+
+# Transforms
+def test_linear_transform():
+    l = pt.transforms.LinearTransform(2, 4)
+    actual = l.weights
+    desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
+    assert torch.allclose(actual, desired)
+
+
+def test_linear_transform_zeros_init():
+    l = pt.transforms.LinearTransform(
+        in_dim=2,
+        out_dim=4,
+        initializer=pt.initializers.ZerosLinearTransformInitializer(),
+    )
+    actual = l.weights
+    desired = torch.zeros(2, 4)
+    assert torch.allclose(actual, desired)
+
+
+def test_linear_transform_out_dim_first():
+    l = pt.transforms.LinearTransform(
+        in_dim=2,
+        out_dim=4,
+        initializer=pt.initializers.OLTI(out_dim_first=True),
+    )
+    assert l.weights.shape[0] == 4
+    assert l.weights.shape[1] == 2
+
+
+# Components
+def test_components_no_initializer():
+    with pytest.raises(TypeError):
+        _ = pt.components.Components(3, None)
+
+
+def test_components_no_num_components():
+    with pytest.raises(TypeError):
+        _ = pt.components.Components(initializer=pt.initializers.OCI(2))
+
+
+def test_components_none_num_components():
+    with pytest.raises(TypeError):
+        _ =
pt.components.Components(None, initializer=pt.initializers.OCI(2)) + + +def test_components_no_args(): + with pytest.raises(TypeError): + _ = pt.components.Components() + + +def test_components_zeros_init(): + c = pt.components.Components(3, pt.initializers.ZCI(2)) + assert torch.allclose(c.components, torch.zeros(3, 2)) + + +def test_labeled_components_dict_init(): + c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long)) + + +def test_labeled_components_list_init(): + c = pt.components.LabeledComponents([3], pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long)) + + +def test_labeled_components_tuple_init(): + c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1])) + + +# Labels +def test_standalone_labels_dict_init(): + l = pt.components.Labels({0: 3}) + assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long)) + + +def test_standalone_labels_list_init(): + l = pt.components.Labels([3]) + assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long)) + + +def test_standalone_labels_tuple_init(): + l = pt.components.Labels({0: 1, 1: 2}) + assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1])) + + +# Losses +def test_glvq_loss_int_labels(): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([0, 1]) + targets = torch.ones(100) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +def test_glvq_loss_one_hot_labels(): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([[0, 1], [1, 0]]) + wl = torch.tensor([1, 0]) + targets = torch.stack([wl for _ in range(100)], dim=0) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +def test_glvq_loss_one_hot_unequal(): + dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)] + d = torch.stack(dlist, dim=1) + labels = torch.tensor([[0, 1], [1, 0], [1, 0]]) + wl = torch.tensor([1, 0]) + targets = torch.stack([wl for _ in range(100)], dim=0) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +# Activations +class TestActivations(unittest.TestCase): + def setUp(self): + self.flist = ["identity", "sigmoid_beta", "swish_beta"] + self.x = torch.randn(1024, 1) + + def test_registry(self): + self.assertIsNotNone(pt.nn.ACTIVATIONS) + + def test_funcname_deserialization(self): + for funcname in self.flist: + f = pt.nn.get_activation(funcname) + iscallable = callable(f) + self.assertTrue(iscallable) + + def test_callable_deserialization(self): + def dummy(x, **kwargs): + return x + + for f in [dummy, lambda x: x]: + f = pt.nn.get_activation(f) + iscallable = callable(f) + self.assertTrue(iscallable) + self.assertEqual(1, f(1)) + + def test_unknown_deserialization(self): + for funcname in ["blubb", "foobar"]: + with self.assertRaises(NameError): + _ = pt.nn.get_activation(funcname) + + def test_identity(self): + actual = pt.nn.identity(self.x) + 
+class TestCompetitions(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_wtac(self):
+        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
+        labels = torch.tensor([0, 1, 2, 3])
+        competition_layer = pt.competitions.WTAC()
+        actual = competition_layer(d, labels)
+        desired = torch.tensor([2, 0])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_wtac_unequal_dist(self):
+        d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
+        labels = torch.tensor([0, 1, 1])
+        competition_layer = pt.competitions.WTAC()
+        actual = competition_layer(d, labels)
+        desired = torch.tensor([0, 1])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_wtac_one_hot(self):
+        d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
+        labels = torch.tensor([[0, 1], [1, 0]])
+        competition_layer = pt.competitions.WTAC()
+        actual = competition_layer(d, labels)
+        desired = torch.tensor([[0, 1], [1, 0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_knnc_k1(self):
+        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
+        labels = torch.tensor([0, 1, 2, 3])
+        competition_layer = pt.competitions.KNNC(k=1)
+        actual = competition_layer(d, labels)
+        desired = torch.tensor([2, 0])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def tearDown(self):
+        pass
+
+
+# Pooling
+class TestPooling(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_stratified_min(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
+        labels = torch.tensor([0, 0, 1, 2])
+        pooling_layer = pt.pooling.StratifiedMinPooling()
+        actual = pooling_layer(d, labels)
+        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_min_one_hot(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
+        labels = torch.tensor([0, 0, 1, 2])
+        labels = torch.eye(3)[labels]
+        pooling_layer = pt.pooling.StratifiedMinPooling()
+        actual = pooling_layer(d, labels)
+        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_min_trivial(self):
+        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
+        labels = torch.tensor([0, 1, 2])
+        pooling_layer = pt.pooling.StratifiedMinPooling()
+        actual = pooling_layer(d, labels)
+        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_max(self):
+        d = 
torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 3, 2, 0]) + pooling_layer = pt.pooling.StratifiedMaxPooling() + actual = pooling_layer(d, labels) + desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_max_one_hot(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 2, 1, 0]) + labels = torch.nn.functional.one_hot(labels, num_classes=3) + pooling_layer = pt.pooling.StratifiedMaxPooling() + actual = pooling_layer(d, labels) + desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_sum(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.LongTensor([0, 0, 1, 2]) + pooling_layer = pt.pooling.StratifiedSumPooling() + actual = pooling_layer(d, labels) + desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_sum_one_hot(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.tensor([0, 0, 1, 2]) + labels = torch.eye(3)[labels] + pooling_layer = pt.pooling.StratifiedSumPooling() + actual = pooling_layer(d, labels) + desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_prod(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 3, 2, 0]) + pooling_layer = pt.pooling.StratifiedProdPooling() + actual = pooling_layer(d, labels) + desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + pass + + +# Distances +class TestDistances(unittest.TestCase): + def setUp(self): + self.nx, self.mx = 32, 2048 + self.ny, self.my = 8, 2048 + self.x = torch.randn(self.nx, self.mx) + self.y = torch.randn(self.ny, self.my) + + def test_manhattan(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=1) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=1, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_euclidean(self): + actual = pt.distances.euclidean_distance(self.x, self.y) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=3) + self.assertIsNone(mismatch) + + def test_squared_euclidean(self): + actual = pt.distances.squared_euclidean_distance(self.x, self.y) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + 
keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_lpnorm_p0(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=0) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=0, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_p2(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=2) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_p3(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=3) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=3, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_pinf(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=float("inf")) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=float("inf"), + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_omega_identity(self): + omega = torch.eye(self.mx, self.my) + actual = pt.distances.omega_distance(self.x, self.y, omega=omega) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_lomega_identity(self): + omega = torch.eye(self.mx, self.my) + omegas = torch.stack([omega for _ in range(self.ny)], dim=0) + actual = pt.distances.lomega_distance(self.x, self.y, omegas=omegas) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def tearDown(self): + del self.x, self.y diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8d109e3..f8c1aba 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,32 +1,97 @@ -"""ProtoTorch datasets test suite.""" +"""ProtoTorch datasets test suite""" import os import shutil import unittest +import numpy as np import torch -from prototorch.datasets import abstract, tecator +import prototorch as pt +from prototorch.datasets.abstract import Dataset, ProtoDataset class TestAbstract(unittest.TestCase): + def setUp(self): + self.ds = 
Dataset("./artifacts") + def test_getitem(self): with self.assertRaises(NotImplementedError): - abstract.Dataset("./artifacts")[0] + _ = self.ds[0] def test_len(self): with self.assertRaises(NotImplementedError): - len(abstract.Dataset("./artifacts")) + _ = len(self.ds) + + def tearDown(self): + del self.ds class TestProtoDataset(unittest.TestCase): - def test_getitem(self): - with self.assertRaises(NotImplementedError): - abstract.ProtoDataset("./artifacts")[0] - def test_download(self): with self.assertRaises(NotImplementedError): - abstract.ProtoDataset("./artifacts").download() + _ = ProtoDataset("./artifacts", download=True) + + def test_exists(self): + with self.assertRaises(RuntimeError): + _ = ProtoDataset("./artifacts", download=False) + + +class TestNumpyDataset(unittest.TestCase): + def test_list_init(self): + ds = pt.datasets.NumpyDataset([1], [1]) + self.assertEqual(len(ds), 1) + + def test_numpy_init(self): + data = np.random.randn(3, 2) + targets = np.array([0, 1, 2]) + ds = pt.datasets.NumpyDataset(data, targets) + self.assertEqual(len(ds), 3) + + +class TestSpiral(unittest.TestCase): + def test_init(self): + ds = pt.datasets.Spiral(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestIris(unittest.TestCase): + def setUp(self): + self.ds = pt.datasets.Iris() + + def test_size(self): + self.assertEqual(len(self.ds), 150) + + def test_dims(self): + self.assertEqual(self.ds.data.shape[1], 4) + + def test_dims_selection(self): + ds = pt.datasets.Iris(dims=[0, 1]) + self.assertEqual(ds.data.shape[1], 2) + + +class TestBlobs(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Blobs(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestRandom(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Random(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestCircles(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Circles(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestMoons(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Moons(num_samples=10) + self.assertEqual(len(ds), 10) class TestTecator(unittest.TestCase): @@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase): rootdir = self.artifacts_dir.rpartition("/")[0] self._remove_artifacts() with self.assertRaises(RuntimeError): - _ = tecator.Tecator(rootdir, download=False) + _ = pt.datasets.Tecator(rootdir, download=False) def test_download_caching(self): rootdir = self.artifacts_dir.rpartition("/")[0] - _ = tecator.Tecator(rootdir, download=True, verbose=False) - _ = tecator.Tecator(rootdir, download=False, verbose=False) + _ = pt.datasets.Tecator(rootdir, download=True, verbose=False) + _ = pt.datasets.Tecator(rootdir, download=False, verbose=False) def test_repr(self): rootdir = self.artifacts_dir.rpartition("/")[0] - train = tecator.Tecator(rootdir, download=True, verbose=True) + train = pt.datasets.Tecator(rootdir, download=True, verbose=True) self.assertTrue("Split: Train" in train.__repr__()) def test_download_train(self): rootdir = self.artifacts_dir.rpartition("/")[0] - train = tecator.Tecator(root=rootdir, - train=True, - download=True, - verbose=False) - train = tecator.Tecator(root=rootdir, download=True, verbose=False) + train = pt.datasets.Tecator(root=rootdir, + train=True, + download=True, + verbose=False) + train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False) x_train, y_train = train.data, train.targets self.assertEqual(x_train.shape[0], 144) self.assertEqual(y_train.shape[0], 144) @@ -68,7 +133,7 @@ class 
TestTecator(unittest.TestCase): def test_download_test(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) x_test, y_test = test.data, test.targets self.assertEqual(x_test.shape[0], 71) self.assertEqual(y_test.shape[0], 71) @@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase): def test_class_to_idx(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) _ = test.class_to_idx def test_getitem(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) x, y = test[0] self.assertEqual(x.shape[0], 100) self.assertIsInstance(y, int) def test_loadable_with_dataloader(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True) def tearDown(self): - pass + self._remove_artifacts() diff --git a/tests/test_functions.py b/tests/test_functions.py deleted file mode 100644 index cd049a8..0000000 --- a/tests/test_functions.py +++ /dev/null @@ -1,581 +0,0 @@ -"""ProtoTorch functions test suite.""" - -import unittest - -import numpy as np -import torch - -from prototorch.functions import (activations, competitions, distances, - initializers, losses, pooling) - - -class TestActivations(unittest.TestCase): - def setUp(self): - self.flist = ["identity", "sigmoid_beta", "swish_beta"] - self.x = torch.randn(1024, 1) - - def test_registry(self): - self.assertIsNotNone(activations.ACTIVATIONS) - - def test_funcname_deserialization(self): - for funcname in self.flist: - f = activations.get_activation(funcname) - iscallable = callable(f) - self.assertTrue(iscallable) - - # def test_torch_script(self): - # for funcname in self.flist: - # f = activations.get_activation(funcname) - # self.assertIsInstance(f, torch.jit.ScriptFunction) - - def test_callable_deserialization(self): - def dummy(x, **kwargs): - return x - - for f in [dummy, lambda x: x]: - f = activations.get_activation(f) - iscallable = callable(f) - self.assertTrue(iscallable) - self.assertEqual(1, f(1)) - - def test_unknown_deserialization(self): - for funcname in ["blubb", "foobar"]: - with self.assertRaises(NameError): - _ = activations.get_activation(funcname) - - def test_identity(self): - actual = activations.identity(self.x) - desired = self.x - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_sigmoid_beta1(self): - actual = activations.sigmoid_beta(self.x, beta=1.0) - desired = torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_swish_beta1(self): - actual = activations.swish_beta(self.x, beta=1.0) - desired = self.x * torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x - - -class TestCompetitions(unittest.TestCase): - def setUp(self): - pass - - def test_wtac(self): - d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) - labels = torch.tensor([0, 
1, 2, 3]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([2, 0]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_wtac_unequal_dist(self): - d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]]) - labels = torch.tensor([0, 1, 1]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([0, 1]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_wtac_one_hot(self): - d = torch.tensor([[1.99, 3.01], [3.0, 2.01]]) - labels = torch.tensor([[0, 1], [1, 0]]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([[0, 1], [1, 0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_knnc_k1(self): - d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) - labels = torch.tensor([0, 1, 2, 3]) - actual = competitions.knnc(d, labels, k=1) - desired = torch.tensor([2, 0]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - pass - - -class TestPooling(unittest.TestCase): - def setUp(self): - pass - - def test_stratified_min(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.tensor([0, 0, 1, 2]) - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_min_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.tensor([0, 0, 1, 2]) - labels = torch.eye(3)[labels] - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_min_trivial(self): - d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) - labels = torch.tensor([0, 1, 2]) - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_max(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pooling.stratified_max_pooling(d, labels) - desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_max_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 2, 1, 0]) - labels = torch.nn.functional.one_hot(labels, num_classes=3) - actual = pooling.stratified_max_pooling(d, labels) - desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_sum(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.LongTensor([0, 0, 1, 2]) - actual = pooling.stratified_sum_pooling(d, labels) - desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def 
test_stratified_sum_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.tensor([0, 0, 1, 2]) - labels = torch.eye(3)[labels] - actual = pooling.stratified_sum_pooling(d, labels) - desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_prod(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pooling.stratified_prod_pooling(d, labels) - desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - pass - - -class TestDistances(unittest.TestCase): - def setUp(self): - self.nx, self.mx = 32, 2048 - self.ny, self.my = 8, 2048 - self.x = torch.randn(self.nx, self.mx) - self.y = torch.randn(self.ny, self.my) - - def test_manhattan(self): - actual = distances.lpnorm_distance(self.x, self.y, p=1) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=1, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_euclidean(self): - actual = distances.euclidean_distance(self.x, self.y) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=3) - self.assertIsNone(mismatch) - - def test_squared_euclidean(self): - actual = distances.squared_euclidean_distance(self.x, self.y) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_lpnorm_p0(self): - actual = distances.lpnorm_distance(self.x, self.y, p=0) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=0, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_lpnorm_p2(self): - actual = distances.lpnorm_distance(self.x, self.y, p=2) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_lpnorm_p3(self): - actual = distances.lpnorm_distance(self.x, self.y, p=3) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=3, - keepdim=False, - ) - mismatch = 
np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_lpnorm_pinf(self): - actual = distances.lpnorm_distance(self.x, self.y, p=float("inf")) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=float("inf"), - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_omega_identity(self): - omega = torch.eye(self.mx, self.my) - actual = distances.omega_distance(self.x, self.y, omega=omega) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_lomega_identity(self): - omega = torch.eye(self.mx, self.my) - omegas = torch.stack([omega for _ in range(self.ny)], dim=0) - actual = distances.lomega_distance(self.x, self.y, omegas=omegas) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x, self.y - - -class TestInitializers(unittest.TestCase): - def setUp(self): - self.flist = [ - "zeros", - "ones", - "rand", - "randn", - "stratified_mean", - "stratified_random", - ] - self.x = torch.tensor( - [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], - dtype=torch.float32) - self.y = torch.tensor([0, 0, 1, 1]) - self.gen = torch.manual_seed(42) - - def test_registry(self): - self.assertIsNotNone(initializers.INITIALIZERS) - - def test_funcname_deserialization(self): - for funcname in self.flist: - f = initializers.get_initializer(funcname) - iscallable = callable(f) - self.assertTrue(iscallable) - - def test_callable_deserialization(self): - def dummy(x): - return x - - for f in [dummy, lambda x: x]: - f = initializers.get_initializer(f) - iscallable = callable(f) - self.assertTrue(iscallable) - self.assertEqual(1, f(1)) - - def test_unknown_deserialization(self): - for funcname in ["blubb", "foobar"]: - with self.assertRaises(NameError): - _ = initializers.get_initializer(funcname) - - def test_zeros(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.zeros(self.x, self.y, pdist) - desired = torch.zeros(2, 3) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_ones(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.ones(self.x, self.y, pdist) - desired = torch.ones(2, 3) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_rand(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.rand(self.x, self.y, pdist) - desired = torch.rand(2, 3, generator=torch.manual_seed(42)) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_randn(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.randn(self.x, self.y, 
pdist) - desired = torch.randn(2, 3, generator=torch.manual_seed(42)) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_equal1(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_equal1(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_equal2(self): - pdist = torch.tensor([2, 2]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_equal2(self): - pdist = torch.tensor([2, 2]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_unequal(self): - pdist = torch.tensor([1, 3]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_unequal(self): - pdist = torch.tensor([1, 3]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_unequal_one_hot(self): - pdist = torch.tensor([1, 3]) - y = torch.eye(2)[self.y] - desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - actual1, actual2 = initializers.stratified_mean(self.x, y, pdist) - desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) - mismatch = np.testing.assert_array_almost_equal(actual1, - desired1, - decimal=5) - mismatch = np.testing.assert_array_almost_equal(actual2, - desired2, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_unequal_one_hot(self): - pdist = torch.tensor([1, 3]) - y = torch.eye(2)[self.y] - actual1, actual2 = initializers.stratified_random(self.x, y, pdist) - desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) - mismatch = np.testing.assert_array_almost_equal(actual1, - desired1, - decimal=5) - mismatch = np.testing.assert_array_almost_equal(actual2, - desired2, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x, self.y, self.gen - _ = torch.seed() - - -class TestLosses(unittest.TestCase): - def setUp(self): - pass - - def test_glvq_loss_int_labels(self): - d = 
torch.stack([torch.ones(100), torch.zeros(100)], dim=1) - labels = torch.tensor([0, 1]) - targets = torch.ones(100) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def test_glvq_loss_one_hot_labels(self): - d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) - labels = torch.tensor([[0, 1], [1, 0]]) - wl = torch.tensor([1, 0]) - targets = torch.stack([wl for _ in range(100)], dim=0) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def test_glvq_loss_one_hot_unequal(self): - dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)] - d = torch.stack(dlist, dim=1) - labels = torch.tensor([[0, 1], [1, 0], [1, 0]]) - wl = torch.tensor([1, 0]) - targets = torch.stack([wl for _ in range(100)], dim=0) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def tearDown(self): - pass diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e8a5e06 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,47 @@ +"""ProtoTorch utils test suite""" + +import numpy as np +import torch + +import prototorch as pt + + +def test_mesh2d_without_input(): + mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10) + assert mesh.shape[0] == 100 + assert mesh.shape[1] == 2 + assert xx.shape[0] == 10 + assert xx.shape[1] == 10 + assert yy.shape[0] == 10 + assert yy.shape[1] == 10 + assert np.min(xx) == -2.0 + assert np.max(xx) == 2.0 + assert np.min(yy) == -2.0 + assert np.max(yy) == 2.0 + + +def test_mesh2d_with_torch_input(): + x = 10 * torch.rand(5, 2) + mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100) + assert mesh.shape[0] == 100 * 100 + assert mesh.shape[1] == 2 + assert xx.shape[0] == 100 + assert xx.shape[1] == 100 + assert yy.shape[0] == 100 + assert yy.shape[1] == 100 + assert np.min(xx) == x[:, 0].min() + assert np.max(xx) == x[:, 0].max() + assert np.min(yy) == x[:, 1].min() + assert np.max(yy) == x[:, 1].max() + + +def test_hex_to_rgb(): + red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0] + assert red_rgb[0] == 255 + assert red_rgb[1] == 0 + assert red_rgb[2] == 0 + + +def test_rgb_to_hex(): + blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0] + assert blue_hex.lower() == "0000ff"
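+
+
+# Note: hex_to_rgb expects "#"-prefixed strings like "#ff0000", whereas
+# rgb_to_hex returns bare hex digits without the "#", hence the "0000ff"
+# comparison above.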