From 24903b761c7e63b5195238d564c17c3f37ffb03f Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 18:48:43 +0200 Subject: [PATCH 01/43] [WIP] Add labels.py --- prototorch/components/__init__.py | 5 +- prototorch/components/components.py | 27 +++++---- prototorch/components/initializers.py | 5 +- prototorch/components/labels.py | 86 +++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 15 deletions(-) create mode 100644 prototorch/components/labels.py diff --git a/prototorch/components/__init__.py b/prototorch/components/__init__.py index 07dd543..69293cb 100644 --- a/prototorch/components/__init__.py +++ b/prototorch/components/__init__.py @@ -1,2 +1,3 @@ -from prototorch.components.components import * -from prototorch.components.initializers import * +from .components import * +from .initializers import * +from .labels import * diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 6d001f7..7ae1df6 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -1,4 +1,4 @@ -"""ProtoTorch components modules.""" +"""ProtoTorch Components.""" import warnings @@ -13,7 +13,7 @@ from torch.nn.parameter import Parameter from .initializers import parse_data_arg -def get_labels_object(distribution): +def get_labels_initializer(distribution): if isinstance(distribution, dict): if "num_classes" in distribution.keys(): labels = EqualLabelsInitializer( @@ -119,10 +119,11 @@ class LabeledComponents(Components): components, component_labels = parse_data_arg( initialized_components) super().__init__(initialized_components=components) + # self._labels = component_labels self._labels = component_labels else: - labels = get_labels_object(distribution) - self.initial_distribution = labels.distribution + labels_initializer = get_labels_initializer(distribution) + self.initial_distribution = labels_initializer.distribution _labels = labels.generate() super().__init__(len(_labels), initializer=initializer) self._register_labels(_labels) @@ -150,8 +151,8 @@ class LabeledComponents(Components): _precheck_initializer(initializer) # Labels - labels = get_labels_object(distribution) - new_labels = labels.generate() + labels_initializer = get_labels_initializer(distribution) + new_labels = labels_initializer.generate() _labels = torch.cat([self._labels, new_labels]) self._register_labels(_labels) @@ -196,20 +197,24 @@ class ReasoningComponents(Components): """ def __init__(self, - reasonings=None, + distribution=None, initializer=None, + reasoning_initializer=None, *, initialized_components=None): if initialized_components is not None: components, reasonings = initialized_components - super().__init__(initialized_components=components) self.register_parameter("_reasonings", reasonings) else: - self._initialize_reasonings(reasonings) - super().__init__(len(self._reasonings), initializer=initializer) + labels_initializer = get_labels_initializer(distribution) + self.initial_distribution = labels_initializer.distribution + super().__init__(len(self.initial_distribution), + initializer=initializer) + reasonings = reasoning_initializer.generate() + self._register_reasonings(reasonings) - def _initialize_reasonings(self, reasonings): + def _initialize_reasonings(self, reasoning_initializer): if isinstance(reasonings, tuple): num_classes, num_components = reasonings reasonings = ZeroReasoningsInitializer(num_classes, num_components) diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py 
index d05c6c7..8839451 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Iterable from itertools import chain +from typing import List import torch from torch.utils.data import DataLoader, Dataset @@ -179,7 +180,7 @@ class UnequalLabelsInitializer(LabelsInitializer): self.clabels = clabels or range(len(self.dist)) @property - def distribution(self): + def distribution(self) -> List: return self.dist def generate(self): @@ -194,7 +195,7 @@ class EqualLabelsInitializer(LabelsInitializer): self.per_class = per_class @property - def distribution(self): + def distribution(self) -> List: return self.classes * [self.per_class] def generate(self): diff --git a/prototorch/components/labels.py b/prototorch/components/labels.py new file mode 100644 index 0000000..bf2620d --- /dev/null +++ b/prototorch/components/labels.py @@ -0,0 +1,86 @@ +"""ProtoTorch Labels.""" + +import torch +from prototorch.components.components import get_labels_initializer +from prototorch.components.initializers import (ClassAwareInitializer, + ComponentsInitializer, + EqualLabelsInitializer, + UnequalLabelsInitializer) +from torch.nn.parameter import Parameter + + +def get_labels_initializer(distribution): + if isinstance(distribution, dict): + if "num_classes" in distribution.keys(): + labels = EqualLabelsInitializer( + distribution["num_classes"], + distribution["prototypes_per_class"]) + else: + clabels = list(distribution.keys()) + dist = list(distribution.values()) + labels = UnequalLabelsInitializer(dist, clabels) + elif isinstance(distribution, tuple): + num_classes, prototypes_per_class = distribution + labels = EqualLabelsInitializer(num_classes, prototypes_per_class) + elif isinstance(distribution, list): + labels = UnequalLabelsInitializer(distribution) + else: + msg = f"`distribution` not understood." \ + f"You have provided: {distribution=}." 
+ raise ValueError(msg) + return labels + + +class Labels(torch.nn.Module): + def __init__(self, + distribution=None, + initializer=None, + *, + initialized_labels=None): + _labels = self.get_labels(distribution, + initializer, + initialized_labels=initialized_labels) + self._register_labels(_labels) + + def _register_labels(self, labels): + # self.register_buffer("_labels", labels) + self.register_parameter("_labels", + Parameter(labels, requires_grad=False)) + + def get_labels(self, + distribution=None, + initializer=None, + *, + initialized_labels=None): + if initialized_labels is not None: + _labels = initialized_labels + else: + labels_initializer = initializer or get_labels_initializer( + distribution) + self.initial_distribution = labels_initializer.distribution + _labels = labels_initializer.generate() + return _labels + + def add_labels(self, + distribution=None, + initializer=None, + *, + initialized_labels=None): + new_labels = self.get_labels(distribution, + initializer, + initialized_labels=initialized_labels) + _labels = torch.cat([self._labels, new_labels]) + self._register_labels(_labels) + + def remove_labels(self, indices=None): + mask = torch.ones(len(self._labels, dtype=torch.bool)) + mask[indices] = False + _labels = self._labels[mask] + self._register_labels(_labels) + + @property + def labels(self): + return self._labels + + def forward(self): + return self._labels From 396d569351c80acd6a9a3bd90fd85c01cfff6601 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:07:07 +0200 Subject: [PATCH 02/43] Add utils.py --- prototorch/utils/utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 prototorch/utils/utils.py diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py new file mode 100644 index 0000000..f48cb31 --- /dev/null +++ b/prototorch/utils/utils.py @@ -0,0 +1,18 @@ +"""ProtoFlow utilities""" + +import numpy as np + + +def mesh2d(x=None, border: float = 1.0, resolution: int = 100): + if x is not None: + x_shift = border * np.ptp(x[:, 0]) + y_shift = border * np.ptp(x[:, 1]) + x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift + y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift + else: + x_min, x_max = -border, border + y_min, y_max = -border, border + xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution), + np.linspace(y_min, y_max, resolution)) + mesh = np.c_[xx.ravel(), yy.ravel()] + return mesh, xx, yy From 56d554ed83ac56b4e5588b4ca4dfb16615400098 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:07:22 +0200 Subject: [PATCH 03/43] Remove celluloid.py --- prototorch/utils/celluloid.py | 46 ----------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 prototorch/utils/celluloid.py diff --git a/prototorch/utils/celluloid.py b/prototorch/utils/celluloid.py deleted file mode 100644 index 56eec36..0000000 --- a/prototorch/utils/celluloid.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Easy matplotlib animation. 
From https://github.com/jwkvam/celluloid.""" - -from collections import defaultdict -from typing import Dict, List - -from matplotlib.animation import ArtistAnimation -from matplotlib.artist import Artist -from matplotlib.figure import Figure - -__version__ = "0.2.0" - - -class Camera: - """Make animations easier.""" - def __init__(self, figure: Figure) -> None: - """Create camera from matplotlib figure.""" - self._figure = figure - # need to keep track off artists for each axis - self._offsets: Dict[str, Dict[int, int]] = { - k: defaultdict(int) - for k in - ["collections", "patches", "lines", "texts", "artists", "images"] - } - self._photos: List[List[Artist]] = [] - - def snap(self) -> List[Artist]: - """Capture current state of the figure.""" - frame_artists: List[Artist] = [] - for i, axis in enumerate(self._figure.axes): - if axis.legend_ is not None: - axis.add_artist(axis.legend_) - for name in self._offsets: - new_artists = getattr(axis, name)[self._offsets[name][i]:] - frame_artists += new_artists - self._offsets[name][i] += len(new_artists) - self._photos.append(frame_artists) - return frame_artists - - def animate(self, *args, **kwargs) -> ArtistAnimation: - """Animate the snapshots taken. - Uses matplotlib.animation.ArtistAnimation - Returns - ------- - ArtistAnimation - """ - return ArtistAnimation(self._figure, self._photos, *args, **kwargs) From 92b8d1785c7159894e5f195512e6073695b43314 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:07:55 +0200 Subject: [PATCH 04/43] Clean colors.py --- prototorch/utils/colors.py | 85 +++++--------------------------------- 1 file changed, 11 insertions(+), 74 deletions(-) diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 65543e4..07e2d5d 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,78 +1,15 @@ """ProtoFlow color utilities.""" -import matplotlib.lines as mlines -from matplotlib import cm -from matplotlib.colors import Normalize, to_hex, to_rgb + +def hex_to_rgb(hex_values): + for v in hex_values: + v = v.lstrip('#') + lv = len(v) + c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)] + yield c -def color_scheme(n, - cmap="viridis", - form="hex", - tikz=False, - zero_indexed=False): - """Return *n* colors from the color scheme. - - Arguments: - n (int): number of colors to return - - Keyword Arguments: - cmap (str): Name of a matplotlib `colormap\ - `_. - form (str): Colorformat (supports "hex" and "rgb"). - tikz (bool): Output as `TikZ `_ - command. - zero_indexed (bool): Use zero indexing for output array. 
-
-    Returns:
-        (list): List of colors
-    """
-    cmap = cm.get_cmap(cmap)
-    colornorm = Normalize(vmin=1, vmax=n)
-    hex_map = dict()
-    rgb_map = dict()
-    for cl in range(1, n + 1):
-        if zero_indexed:
-            hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
-        else:
-            hex_map[cl] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
-    if tikz:
-        for k, v in rgb_map.items():
-            print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
-    if form == "hex":
-        return hex_map
-    elif form == "rgb":
-        return rgb_map
-    else:
-        return hex_map
-
-
-def get_legend_handles(labels, marker="dots", zero_indexed=False):
-    """Return matplotlib legend handles and colors."""
-    handles = list()
-    n = len(labels)
-    colors = color_scheme(n,
-                          cmap="viridis",
-                          form="hex",
-                          zero_indexed=zero_indexed)
-    for label, color in zip(labels, colors.values()):
-        if marker == "dots":
-            handle = mlines.Line2D(
-                [],
-                [],
-                color="white",
-                markerfacecolor=color,
-                marker="o",
-                markersize=10,
-                markeredgecolor="k",
-                label=label,
-            )
-        else:
-            handle = mlines.Line2D([], [],
-                                   color=color,
-                                   marker="",
-                                   markersize=15,
-                                   label=label)
-        handles.append(handle)
-    return handles, colors

From abae72d6243f90197a2672f30b0f050582fcc3f3 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Fri, 11 Jun 2021 23:08:12 +0200
Subject: [PATCH 05/43] Update utils module

---
 prototorch/utils/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/prototorch/utils/__init__.py b/prototorch/utils/__init__.py
index e69de29..2f18133 100644
--- a/prototorch/utils/__init__.py
+++ b/prototorch/utils/__init__.py
@@ -0,0 +1,4 @@
+"""ProtoFlow utils module"""
+
+from .colors import hex_to_rgb, rgb_to_hex
+from .utils import mesh2d

From 0b2aaa42b8c8ee00ab5ba91f5dc7347887487e7e Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Fri, 11 Jun 2021 23:08:32 +0200
Subject: [PATCH 06/43] Add utils test suite

---
 tests/test_utils.py | 47 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)
 create mode 100644 tests/test_utils.py

diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..e8a5e06
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,47 @@
+"""ProtoTorch utils test suite"""
+
+import numpy as np
+import torch
+
+import prototorch as pt
+
+
+def test_mesh2d_without_input():
+    mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10)
+    assert mesh.shape[0] == 100
+    assert mesh.shape[1] == 2
+    assert xx.shape[0] == 10
+    assert xx.shape[1] == 10
+    assert yy.shape[0] == 10
+    assert yy.shape[1] == 10
+    assert np.min(xx) == -2.0
+    assert np.max(xx) == 2.0
+    assert np.min(yy) == -2.0
+    assert np.max(yy) == 2.0
+
+
+def test_mesh2d_with_torch_input():
+    x = 10 * torch.rand(5, 2)
+    mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100)
+    assert mesh.shape[0] == 100 * 100
+    assert mesh.shape[1] == 2
+    assert xx.shape[0] == 100
+    assert xx.shape[1] == 100
+    assert yy.shape[0] == 100
+    assert yy.shape[1] == 100
+    assert np.min(xx) == x[:, 0].min()
+    assert np.max(xx) == x[:, 0].max()
+    assert np.min(yy) == x[:, 1].min()
+    assert np.max(yy) == x[:, 1].max()
+
+
+def test_hex_to_rgb():
+    red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0]
+    assert red_rgb[0] == 255
+    assert red_rgb[1] == 0
+    assert red_rgb[2] == 0
+
+
+def test_rgb_to_hex():
+    blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0]
+    assert blue_hex.lower() == "0000ff"
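The tests above double as documentation for the new helpers; the following is a minimal usage sketch, assuming only the `prototorch.utils` exports added in patch 05:

    import numpy as np
    from prototorch.utils import hex_to_rgb, mesh2d, rgb_to_hex

    # A 2D mesh around some data, padded by `border` times the data range
    x = np.random.randn(10, 2)
    mesh, xx, yy = mesh2d(x, border=0.1, resolution=50)
    # mesh: (2500, 2); xx and yy: (50, 50), ready for contour plots

    # The color helpers are generators, one conversion per input value
    list(hex_to_rgb(["#ff0000"]))    # -> [[255, 0, 0]]
    list(rgb_to_hex([(0, 0, 255)]))  # -> ["0000ff"]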
"0000ff" From 44e47093875cf33101d5980c3f82b495a67f4a22 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:42:19 +0200 Subject: [PATCH 07/43] Minor aesthetic changes --- prototorch/datasets/__init__.py | 10 ++++++++-- prototorch/datasets/abstract.py | 25 +++++++++++++------------ prototorch/utils/colors.py | 2 +- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 1d61061..096fc6f 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,6 +1,12 @@ -"""ProtoTorch datasets.""" +"""ProtoTorch datasets""" from .abstract import NumpyDataset -from .sklearn import Blobs, Circles, Iris, Moons, Random +from .sklearn import ( + Blobs, + Circles, + Iris, + Moons, + Random, +) from .spiral import Spiral from .tecator import Tecator diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index e941c95..dac8f8c 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -1,10 +1,11 @@ -"""ProtoTorch abstract dataset classes. +"""ProtoTorch abstract dataset classes -Based on `torchvision.VisionDataset` and `torchvision.MNIST` +Based on `torchvision.VisionDataset` and `torchvision.MNIST`. For the original code, see: https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + """ import os @@ -12,15 +13,6 @@ import os import torch -class NumpyDataset(torch.utils.data.TensorDataset): - """Create a PyTorch TensorDataset from NumPy arrays.""" - def __init__(self, data, targets): - self.data = torch.Tensor(data) - self.targets = torch.LongTensor(targets) - tensors = [self.data, self.targets] - super().__init__(*tensors) - - class Dataset(torch.utils.data.Dataset): """Abstract dataset class to be inherited.""" @@ -44,7 +36,7 @@ class ProtoDataset(Dataset): training_file = "training.pt" test_file = "test.pt" - def __init__(self, root, train=True, download=True, verbose=True): + def __init__(self, root="", train=True, download=True, verbose=True): super().__init__(root) self.train = train # training set or test set self.verbose = verbose @@ -96,3 +88,12 @@ class ProtoDataset(Dataset): def _download(self): raise NotImplementedError + + +class NumpyDataset(torch.utils.data.TensorDataset): + """Create a PyTorch TensorDataset from NumPy arrays.""" + def __init__(self, data, targets): + self.data = torch.Tensor(data) + self.targets = torch.LongTensor(targets) + tensors = [self.data, self.targets] + super().__init__(*tensors) diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 07e2d5d..61ad1a0 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,4 +1,4 @@ -"""ProtoFlow color utilities.""" +"""ProtoFlow color utilities""" def hex_to_rgb(hex_values): From 4a99bcbf0db50aa5a10632268c37484ff68ec6af Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:43:18 +0200 Subject: [PATCH 08/43] Update datasets test suite --- tests/test_datasets.py | 111 ++++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 23 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8d109e3..f8c1aba 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,32 +1,97 @@ -"""ProtoTorch datasets test suite.""" +"""ProtoTorch datasets test suite""" import os import shutil import unittest +import numpy as np import torch -from prototorch.datasets 
import abstract, tecator +import prototorch as pt +from prototorch.datasets.abstract import Dataset, ProtoDataset class TestAbstract(unittest.TestCase): + def setUp(self): + self.ds = Dataset("./artifacts") + def test_getitem(self): with self.assertRaises(NotImplementedError): - abstract.Dataset("./artifacts")[0] + _ = self.ds[0] def test_len(self): with self.assertRaises(NotImplementedError): - len(abstract.Dataset("./artifacts")) + _ = len(self.ds) + + def tearDown(self): + del self.ds class TestProtoDataset(unittest.TestCase): - def test_getitem(self): - with self.assertRaises(NotImplementedError): - abstract.ProtoDataset("./artifacts")[0] - def test_download(self): with self.assertRaises(NotImplementedError): - abstract.ProtoDataset("./artifacts").download() + _ = ProtoDataset("./artifacts", download=True) + + def test_exists(self): + with self.assertRaises(RuntimeError): + _ = ProtoDataset("./artifacts", download=False) + + +class TestNumpyDataset(unittest.TestCase): + def test_list_init(self): + ds = pt.datasets.NumpyDataset([1], [1]) + self.assertEqual(len(ds), 1) + + def test_numpy_init(self): + data = np.random.randn(3, 2) + targets = np.array([0, 1, 2]) + ds = pt.datasets.NumpyDataset(data, targets) + self.assertEqual(len(ds), 3) + + +class TestSpiral(unittest.TestCase): + def test_init(self): + ds = pt.datasets.Spiral(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestIris(unittest.TestCase): + def setUp(self): + self.ds = pt.datasets.Iris() + + def test_size(self): + self.assertEqual(len(self.ds), 150) + + def test_dims(self): + self.assertEqual(self.ds.data.shape[1], 4) + + def test_dims_selection(self): + ds = pt.datasets.Iris(dims=[0, 1]) + self.assertEqual(ds.data.shape[1], 2) + + +class TestBlobs(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Blobs(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestRandom(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Random(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestCircles(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Circles(num_samples=10) + self.assertEqual(len(ds), 10) + + +class TestMoons(unittest.TestCase): + def test_size(self): + ds = pt.datasets.Moons(num_samples=10) + self.assertEqual(len(ds), 10) class TestTecator(unittest.TestCase): @@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase): rootdir = self.artifacts_dir.rpartition("/")[0] self._remove_artifacts() with self.assertRaises(RuntimeError): - _ = tecator.Tecator(rootdir, download=False) + _ = pt.datasets.Tecator(rootdir, download=False) def test_download_caching(self): rootdir = self.artifacts_dir.rpartition("/")[0] - _ = tecator.Tecator(rootdir, download=True, verbose=False) - _ = tecator.Tecator(rootdir, download=False, verbose=False) + _ = pt.datasets.Tecator(rootdir, download=True, verbose=False) + _ = pt.datasets.Tecator(rootdir, download=False, verbose=False) def test_repr(self): rootdir = self.artifacts_dir.rpartition("/")[0] - train = tecator.Tecator(rootdir, download=True, verbose=True) + train = pt.datasets.Tecator(rootdir, download=True, verbose=True) self.assertTrue("Split: Train" in train.__repr__()) def test_download_train(self): rootdir = self.artifacts_dir.rpartition("/")[0] - train = tecator.Tecator(root=rootdir, - train=True, - download=True, - verbose=False) - train = tecator.Tecator(root=rootdir, download=True, verbose=False) + train = pt.datasets.Tecator(root=rootdir, + train=True, + download=True, + verbose=False) + train = pt.datasets.Tecator(root=rootdir, 
download=True, verbose=False) x_train, y_train = train.data, train.targets self.assertEqual(x_train.shape[0], 144) self.assertEqual(y_train.shape[0], 144) @@ -68,7 +133,7 @@ class TestTecator(unittest.TestCase): def test_download_test(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) x_test, y_test = test.data, test.targets self.assertEqual(x_test.shape[0], 71) self.assertEqual(y_test.shape[0], 71) @@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase): def test_class_to_idx(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) _ = test.class_to_idx def test_getitem(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) x, y = test[0] self.assertEqual(x.shape[0], 100) self.assertIsInstance(y, int) def test_loadable_with_dataloader(self): rootdir = self.artifacts_dir.rpartition("/")[0] - test = tecator.Tecator(root=rootdir, train=False, verbose=False) + test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False) _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True) def tearDown(self): - pass + self._remove_artifacts() From 5e72fd8187283c061bdef0c2d6de4e43c6dcc513 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 04:54:54 +0200 Subject: [PATCH 09/43] Remove test_components.py --- tests/test_components.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 tests/test_components.py diff --git a/tests/test_components.py b/tests/test_components.py deleted file mode 100644 index 03bc215..0000000 --- a/tests/test_components.py +++ /dev/null @@ -1,25 +0,0 @@ -"""ProtoTorch components test suite.""" - -import prototorch as pt -import torch - - -def test_labcomps_zeros_init(): - protos = torch.zeros(3, 2) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=pt.components.Zeros(2), - ) - assert (c.components == protos).any() == True - - -def test_labcomps_warmstart(): - protos = torch.randn(3, 2) - plabels = torch.tensor([1, 2, 3]) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=None, - initialized_components=[protos, plabels], - ) - assert (c.components == protos).any() == True - assert (c.component_labels == plabels).any() == True From dfefd128c4da7181858e80cc3e28fa905493138a Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 04:57:26 +0200 Subject: [PATCH 10/43] Update gitignore --- .gitignore | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 78ae7a0..0b72579 100644 --- a/.gitignore +++ b/.gitignore @@ -129,14 +129,6 @@ dmypy.json # End of https://www.gitignore.io/api/python -# ProtoFlow -core -checkpoint -logs/ -saved_weights/ -scratch* - - # Created by https://www.gitignore.io/api/visualstudiocode # Edit at https://www.gitignore.io/?templates=visualstudiocode @@ -154,5 +146,6 @@ scratch* # End of https://www.gitignore.io/api/visualstudiocode .vscode/ +# ProtoTorch artifacts reports artifacts \ No newline at end of file From b8969347b12cea2a4e89dfeee9e65c537e8ae330 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 04:58:11 +0200 Subject: [PATCH 11/43] 
Add more utils

---
 prototorch/utils/__init__.py |  6 ++-
 prototorch/utils/utils.py    | 75 ++++++++++++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/prototorch/utils/__init__.py b/prototorch/utils/__init__.py
index 2f18133..26ccedd 100644
--- a/prototorch/utils/__init__.py
+++ b/prototorch/utils/__init__.py
@@ -1,4 +1,8 @@
 """ProtoFlow utils module"""
 
 from .colors import hex_to_rgb, rgb_to_hex
-from .utils import mesh2d
+from .utils import (
+    mesh2d,
+    parse_data_arg,
+    parse_distribution,
+)

diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py
index f48cb31..316d5eb 100644
--- a/prototorch/utils/utils.py
+++ b/prototorch/utils/utils.py
@@ -1,6 +1,11 @@
 """ProtoFlow utilities"""
 
+import warnings
+from typing import Union
+
 import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
 
 
 def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
@@ -16,3 +21,73 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
     return mesh, xx, yy
+
+
+def parse_distribution(user_distribution: Union[dict, list, tuple]):
+    """Parse user-provided distribution.
+
+    Return a dictionary with integer keys that represent the class labels and
+    values that denote the number of components/prototypes with that class
+    label.
+
+    The argument `user_distribution` could be any one of a number of allowed
+    formats. If it is a Python list, it is assumed that there are as many
+    entries in this list as there are classes, and the value at each index of
+    this list describes the number of prototypes for that particular class.
+    So, [1, 1, 1] implies that we have three classes with one prototype per
+    class. If it is a Python tuple, the shorthand (num_classes, per_class) is
+    assumed. If it is a Python dictionary, the key-value pairs describe the
+    class label and the number of prototypes for that class respectively. So,
+    {0: 2, 1: 2, 2: 2} implies that we have three classes with labels
+    {0, 1, 2}, each equipped with two prototypes. If, however, the dictionary
+    contains the keys "num_classes" and "per_class", it is expanded
+    accordingly: {"num_classes": 3, "per_class": 2} also yields
+    {0: 2, 1: 2, 2: 2}.
+
+    """
+    def from_list(list_dist):
+        clabels = list(range(len(list_dist)))
+        distribution = dict(zip(clabels, list_dist))
+        return distribution
+
+    if isinstance(user_distribution, dict):
+        if "num_classes" in user_distribution.keys():
+            num_classes = user_distribution["num_classes"]
+            per_class = user_distribution["per_class"]
+            return from_list([per_class] * num_classes)
+        else:
+            return user_distribution
+    elif isinstance(user_distribution, tuple):
+        assert len(user_distribution) == 2
+        num_classes, per_class = user_distribution
+        return from_list([per_class] * num_classes)
+    elif isinstance(user_distribution, list):
+        return from_list(user_distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+            f"You have provided: {user_distribution}."
+        raise ValueError(msg)
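Since every accepted format normalizes to the same dictionary, the mapping is easiest to see by example (a sketch, assuming the export added to `prototorch.utils` above):

    from prototorch.utils import parse_distribution

    parse_distribution([1, 2, 3])     # -> {0: 1, 1: 2, 2: 3}
    parse_distribution((3, 2))        # -> {0: 2, 1: 2, 2: 2}
    parse_distribution({0: 2, 2: 1})  # -> {0: 2, 2: 1} (passed through)
    parse_distribution({"num_classes": 3, "per_class": 2})
    # -> {0: 2, 1: 2, 2: 2}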
+
+
+def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
+    if isinstance(data_arg, Dataset):
+        ds_size = len(data_arg)
+        data_arg = DataLoader(data_arg, batch_size=ds_size)
+
+    if isinstance(data_arg, DataLoader):
+        data = torch.tensor([])
+        targets = torch.tensor([])
+        for x, y in data_arg:
+            data = torch.cat([data, x])
+            targets = torch.cat([targets, y])
+    else:
+        assert len(data_arg) == 2
+        data, targets = data_arg
+        if not isinstance(data, torch.Tensor):
+            wmsg = f"Converting data to {torch.Tensor}."
+            warnings.warn(wmsg)
+            data = torch.Tensor(data)
+        if not isinstance(targets, torch.LongTensor):
+            wmsg = f"Converting targets to {torch.LongTensor}."
+            warnings.warn(wmsg)
+            targets = torch.LongTensor(targets)
+    return data, targets

From 5dddb39ec4c233d491395a6104eb9d2f7d9017e1 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Sat, 12 Jun 2021 20:29:24 +0200
Subject: [PATCH 12/43] [REFACTOR] Clean and move components and initializers into core

---
 prototorch/components/components.py         | 235 ------------------
 prototorch/components/initializers.py       | 225 -----------------
 prototorch/components/labels.py             |  86 -------
 prototorch/{components => core}/__init__.py |   2 +
 prototorch/core/components.py               | 243 ++++++++++++++++++
 prototorch/core/initializers.py             | 258 ++++++++++++++++++++
 prototorch/utils/utils.py                   |  10 +-
 7 files changed, 510 insertions(+), 549 deletions(-)
 delete mode 100644 prototorch/components/components.py
 delete mode 100644 prototorch/components/initializers.py
 delete mode 100644 prototorch/components/labels.py
 rename prototorch/{components => core}/__init__.py (76%)
 create mode 100644 prototorch/core/components.py
 create mode 100644 prototorch/core/initializers.py

diff --git a/prototorch/components/components.py b/prototorch/components/components.py
deleted file mode 100644
index 7ae1df6..0000000
--- a/prototorch/components/components.py
+++ /dev/null
@@ -1,235 +0,0 @@
-"""ProtoTorch Components."""
-
-import warnings
-
-import torch
-from prototorch.components.initializers import (ClassAwareInitializer,
-                                                ComponentsInitializer,
-                                                EqualLabelsInitializer,
-                                                UnequalLabelsInitializer,
-                                                ZeroReasoningsInitializer)
-from torch.nn.parameter import Parameter
-
-from .initializers import parse_data_arg
-
-
-def get_labels_initializer(distribution):
-    if isinstance(distribution, dict):
-        if "num_classes" in distribution.keys():
-            labels = EqualLabelsInitializer(
-                distribution["num_classes"],
-                distribution["prototypes_per_class"])
-        else:
-            clabels = list(distribution.keys())
-            dist = list(distribution.values())
-            labels = UnequalLabelsInitializer(dist, clabels)
-    elif isinstance(distribution, tuple):
-        num_classes, prototypes_per_class = distribution
-        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-    elif isinstance(distribution, list):
-        labels = UnequalLabelsInitializer(distribution)
-    else:
-        msg = f"`distribution` not understood." \
-            f"You have provided: {distribution=}."
-        raise ValueError(msg)
-    return labels
-
-
-def _precheck_initializer(initializer):
-    if not isinstance(initializer, ComponentsInitializer):
-        emsg = f"`initializer` has to be some subtype of " \
-            f"{ComponentsInitializer}. " \
-            f"You have provided: {initializer=} instead."
- raise TypeError(emsg) - - -class Components(torch.nn.Module): - """Components is a set of learnable Tensors.""" - def __init__(self, - num_components=None, - initializer=None, - *, - initialized_components=None): - super().__init__() - - # Ignore all initialization settings if initialized_components is given. - if initialized_components is not None: - self._register_components(initialized_components) - if num_components is not None or initializer is not None: - wmsg = "Arguments ignored while initializing Components" - warnings.warn(wmsg) - else: - self._initialize_components(num_components, initializer) - - @property - def num_components(self): - return len(self._components) - - def _register_components(self, components): - self.register_parameter("_components", Parameter(components)) - - def _initialize_components(self, num_components, initializer): - _precheck_initializer(initializer) - _components = initializer.generate(num_components) - self._register_components(_components) - - def add_components(self, - num=1, - initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - _components = torch.cat([self._components, initialized_components]) - else: - _precheck_initializer(initializer) - _new = initializer.generate(num) - _components = torch.cat([self._components, _new]) - self._register_components(_components) - - def remove_components(self, indices=None): - mask = torch.ones(self.num_components, dtype=torch.bool) - mask[indices] = False - _components = self._components[mask] - self._register_components(_components) - return mask - - @property - def components(self): - """Tensor containing the component tensors.""" - return self._components.detach() - - def forward(self): - return self._components - - def extra_repr(self): - return f"(components): (shape: {tuple(self._components.shape)})" - - -class LabeledComponents(Components): - """LabeledComponents generate a set of components and a set of labels. - - Every Component has a label assigned. 
- """ - def __init__(self, - distribution=None, - initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - components, component_labels = parse_data_arg( - initialized_components) - super().__init__(initialized_components=components) - # self._labels = component_labels - self._labels = component_labels - else: - labels_initializer = get_labels_initializer(distribution) - self.initial_distribution = labels_initializer.distribution - _labels = labels.generate() - super().__init__(len(_labels), initializer=initializer) - self._register_labels(_labels) - - def _register_labels(self, labels): - self.register_buffer("_labels", labels) - - @property - def distribution(self): - clabels, counts = torch.unique(self._labels, - sorted=True, - return_counts=True) - return dict(zip(clabels.tolist(), counts.tolist())) - - def _initialize_components(self, num_components, initializer): - if isinstance(initializer, ClassAwareInitializer): - _precheck_initializer(initializer) - _components = initializer.generate(num_components, - self.initial_distribution) - self._register_components(_components) - else: - super()._initialize_components(num_components, initializer) - - def add_components(self, distribution, initializer): - _precheck_initializer(initializer) - - # Labels - labels_initializer = get_labels_initializer(distribution) - new_labels = labels_initializer.generate() - _labels = torch.cat([self._labels, new_labels]) - self._register_labels(_labels) - - # Components - if isinstance(initializer, ClassAwareInitializer): - _new = initializer.generate(len(new_labels), distribution) - else: - _new = initializer.generate(len(new_labels)) - _components = torch.cat([self._components, _new]) - self._register_components(_components) - - def remove_components(self, indices=None): - # Components - mask = super().remove_components(indices) - - # Labels - _labels = self._labels[mask] - self._register_labels(_labels) - - @property - def component_labels(self): - """Tensor containing the component tensors.""" - return self._labels.detach() - - def forward(self): - return super().forward(), self._labels - - -class ReasoningComponents(Components): - """ReasoningComponents generate a set of components and a set of reasoning matrices. - - Every Component has a reasoning matrix assigned. - - A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The - first element is called positive reasoning :math:`p`, the second negative - reasoning :math:`n`. A components can reason in favour (positive) of a - class, against (negative) a class or not at all (neutral). - - It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 - \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a - three element probability distribution. 
- - """ - def __init__(self, - distribution=None, - initializer=None, - reasoning_initializer=None, - *, - initialized_components=None): - if initialized_components is not None: - components, reasonings = initialized_components - super().__init__(initialized_components=components) - self.register_parameter("_reasonings", reasonings) - else: - labels_initializer = get_labels_initializer(distribution) - self.initial_distribution = labels_initializer.distribution - super().__init__(len(self.initial_distribution), - initializer=initializer) - reasonings = reasoning_initializer.generate() - self._register_reasonings(reasonings) - - def _initialize_reasonings(self, reasoning_initializer): - if isinstance(reasonings, tuple): - num_classes, num_components = reasonings - reasonings = ZeroReasoningsInitializer(num_classes, num_components) - - _reasonings = reasonings.generate() - self.register_parameter("_reasonings", _reasonings) - - @property - def reasonings(self): - """Returns Reasoning Matrix. - - Dimension NxCx2 - - """ - return self._reasonings.detach() - - def forward(self): - return super().forward(), self._reasonings diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py deleted file mode 100644 index 8839451..0000000 --- a/prototorch/components/initializers.py +++ /dev/null @@ -1,225 +0,0 @@ -"""ProtoTroch Initializers.""" -import warnings -from collections.abc import Iterable -from itertools import chain -from typing import List - -import torch -from torch.utils.data import DataLoader, Dataset - - -def parse_data_arg(data_arg): - if isinstance(data_arg, Dataset): - data_arg = DataLoader(data_arg, batch_size=len(data_arg)) - - if isinstance(data_arg, DataLoader): - data = torch.tensor([]) - targets = torch.tensor([]) - for x, y in data_arg: - data = torch.cat([data, x]) - targets = torch.cat([targets, y]) - else: - data, targets = data_arg - if not isinstance(data, torch.Tensor): - wmsg = f"Converting data to {torch.Tensor}." - warnings.warn(wmsg) - data = torch.Tensor(data) - if not isinstance(targets, torch.Tensor): - wmsg = f"Converting targets to {torch.Tensor}." 
- warnings.warn(wmsg) - targets = torch.Tensor(targets) - return data, targets - - -def get_subinitializers(data, targets, clabels, subinit_type): - initializers = dict() - for clabel in clabels: - class_data = data[targets == clabel] - class_initializer = subinit_type(class_data) - initializers[clabel] = (class_initializer) - return initializers - - -# Components -class ComponentsInitializer(object): - def generate(self, number_of_components): - raise NotImplementedError("Subclasses should implement this!") - - -class DimensionAwareInitializer(ComponentsInitializer): - def __init__(self, dims): - super().__init__() - if isinstance(dims, Iterable): - self.components_dims = tuple(dims) - else: - self.components_dims = (dims, ) - - -class OnesInitializer(DimensionAwareInitializer): - def __init__(self, dims, scale=1.0): - super().__init__(dims) - self.scale = scale - - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims) * self.scale - - -class ZerosInitializer(DimensionAwareInitializer): - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.zeros(gen_dims) - - -class UniformInitializer(DimensionAwareInitializer): - def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0): - super().__init__(dims) - self.minimum = minimum - self.maximum = maximum - self.scale = scale - - def generate(self, length): - gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims).uniform_(self.minimum, - self.maximum) * self.scale - - -class DataAwareInitializer(ComponentsInitializer): - def __init__(self, data, transform=torch.nn.Identity()): - super().__init__() - self.data = data - self.transform = transform - - def __del__(self): - del self.data - - -class SelectionInitializer(DataAwareInitializer): - def generate(self, length): - indices = torch.LongTensor(length).random_(0, len(self.data)) - return self.transform(self.data[indices]) - - -class MeanInitializer(DataAwareInitializer): - def generate(self, length): - mean = torch.mean(self.data, dim=0) - repeat_dim = [length] + [1] * len(mean.shape) - return self.transform(mean.repeat(repeat_dim)) - - -class ClassAwareInitializer(DataAwareInitializer): - def __init__(self, data, transform=torch.nn.Identity()): - data, targets = parse_data_arg(data) - super().__init__(data, transform) - self.targets = targets - self.clabels = torch.unique(self.targets).int().tolist() - self.num_classes = len(self.clabels) - - def _get_samples_from_initializer(self, length, dist): - if not dist: - per_class = length // self.num_classes - dist = dict(zip(self.clabels, self.num_classes * [per_class])) - if isinstance(dist, list): - dist = dict(zip(self.clabels, dist)) - samples = [self.initializers[k].generate(n) for k, n in dist.items()] - out = torch.vstack(samples) - with torch.no_grad(): - out = self.transform(out) - return out - - def __del__(self): - del self.data - del self.targets - - -class StratifiedMeanInitializer(ClassAwareInitializer): - def __init__(self, data, **kwargs): - super().__init__(data, **kwargs) - self.initializers = get_subinitializers(self.data, self.targets, - self.clabels, MeanInitializer) - - def generate(self, length, dist): - samples = self._get_samples_from_initializer(length, dist) - return samples - - -class StratifiedSelectionInitializer(ClassAwareInitializer): - def __init__(self, data, noise=None, **kwargs): - super().__init__(data, **kwargs) - self.noise = noise - self.initializers = get_subinitializers(self.data, self.targets, - 
self.clabels, - SelectionInitializer) - - def add_noise_v1(self, x): - return x + self.noise - - def add_noise_v2(self, x): - """Shifts some dimensions of the data randomly.""" - n1 = torch.rand_like(x) - n2 = torch.rand_like(x) - mask = torch.bernoulli(n1) - torch.bernoulli(n2) - return x + (self.noise * mask) - - def generate(self, length, dist): - samples = self._get_samples_from_initializer(length, dist) - if self.noise is not None: - samples = self.add_noise_v1(samples) - return samples - - -# Labels -class LabelsInitializer: - def generate(self): - raise NotImplementedError("Subclasses should implement this!") - - -class UnequalLabelsInitializer(LabelsInitializer): - def __init__(self, dist, clabels=None): - self.dist = dist - self.clabels = clabels or range(len(self.dist)) - - @property - def distribution(self) -> List: - return self.dist - - def generate(self): - targets = list( - chain(*[[i] * n for i, n in zip(self.clabels, self.dist)])) - return torch.LongTensor(targets) - - -class EqualLabelsInitializer(LabelsInitializer): - def __init__(self, classes, per_class): - self.classes = classes - self.per_class = per_class - - @property - def distribution(self) -> List: - return self.classes * [self.per_class] - - def generate(self): - return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() - - -# Reasonings -class ReasoningsInitializer: - def generate(self, length): - raise NotImplementedError("Subclasses should implement this!") - - -class ZeroReasoningsInitializer(ReasoningsInitializer): - def __init__(self, classes, length): - self.classes = classes - self.length = length - - def generate(self): - return torch.zeros((self.length, self.classes, 2)) - - -# Aliases -SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer -SMI = StratifiedMeanInitializer -Random = RandomInitializer = UniformInitializer -Zeros = ZerosInitializer -Ones = OnesInitializer diff --git a/prototorch/components/labels.py b/prototorch/components/labels.py deleted file mode 100644 index bf2620d..0000000 --- a/prototorch/components/labels.py +++ /dev/null @@ -1,86 +0,0 @@ -"""ProtoTorch Labels.""" - -import torch -from prototorch.components.components import get_labels_initializer -from prototorch.components.initializers import (ClassAwareInitializer, - ComponentsInitializer, - EqualLabelsInitializer, - UnequalLabelsInitializer) -from torch.nn.parameter import Parameter - - -def get_labels_initializer(distribution): - if isinstance(distribution, dict): - if "num_classes" in distribution.keys(): - labels = EqualLabelsInitializer( - distribution["num_classes"], - distribution["prototypes_per_class"]) - else: - clabels = list(distribution.keys()) - dist = list(distribution.values()) - labels = UnequalLabelsInitializer(dist, clabels) - elif isinstance(distribution, tuple): - num_classes, prototypes_per_class = distribution - labels = EqualLabelsInitializer(num_classes, prototypes_per_class) - elif isinstance(distribution, list): - labels = UnequalLabelsInitializer(distribution) - else: - msg = f"`distribution` not understood." \ - f"You have provided: {distribution=}." 
- raise ValueError(msg) - return labels - - -class Labels(torch.nn.Module): - def __init__(self, - distribution=None, - initializer=None, - *, - initialized_labels=None): - _labels = self.get_labels(distribution, - initializer, - initialized_labels=initialized_labels) - self._register_labels(_labels) - - def _register_labels(self, labels): - # self.register_buffer("_labels", labels) - self.register_parameter("_labels", - Parameter(labels, requires_grad=False)) - - def get_labels(self, - distribution=None, - initializer=None, - *, - initialized_labels=None): - if initialized_labels is not None: - _labels = initialized_labels - else: - labels_initializer = initializer or get_labels_initializer( - distribution) - self.initial_distribution = labels_initializer.distribution - _labels = labels_initializer.generate() - return _labels - - def add_labels(self, - distribution=None, - initializer=None, - *, - initialized_labels=None): - new_labels = self.get_labels(distribution, - initializer, - initialized_labels=initialized_labels) - _labels = torch.cat([self._labels, new_labels]) - self._register_labels(_labels) - - def remove_labels(self, indices=None): - mask = torch.ones(len(self._labels, dtype=torch.bool)) - mask[indices] = False - _labels = self._labels[mask] - self._register_labels(_labels) - - @property - def labels(self): - return self._labels - - def forward(self): - return self._labels diff --git a/prototorch/components/__init__.py b/prototorch/core/__init__.py similarity index 76% rename from prototorch/components/__init__.py rename to prototorch/core/__init__.py index 69293cb..17be644 100644 --- a/prototorch/components/__init__.py +++ b/prototorch/core/__init__.py @@ -1,3 +1,5 @@ +"""ProtoTorch core""" + from .components import * from .initializers import * from .labels import * diff --git a/prototorch/core/components.py b/prototorch/core/components.py new file mode 100644 index 0000000..53555af --- /dev/null +++ b/prototorch/core/components.py @@ -0,0 +1,243 @@ +"""ProtoTorch components""" + +import inspect +from typing import Union + +import torch +from torch.nn.parameter import Parameter + +from ..utils import parse_distribution +from .initializers import ( + AbstractComponentsInitializer, + AbstractLabelsInitializer, + AbstractReasoningsInitializer, + ClassAwareCompInitializer, + LabelsInitializer, +) + + +def validate_initializer(initializer, instanceof): + if not isinstance(initializer, instanceof): + emsg = f"`initializer` has to be an instance " \ + f"of some subtype of {instanceof}. " \ + f"You have provided: {initializer} instead. " + helpmsg = "" + if inspect.isclass(initializer): + helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \ + f"with the brackets instead of just {initializer.__name__}?" 
+        raise TypeError(emsg + helpmsg)
+    return True
+
+
+def validate_components_initializer(initializer):
+    return validate_initializer(initializer, AbstractComponentsInitializer)
+
+
+def validate_labels_initializer(initializer):
+    return validate_initializer(initializer, AbstractLabelsInitializer)
+
+
+def validate_reasonings_initializer(initializer):
+    return validate_initializer(initializer, AbstractReasoningsInitializer)
+
+
+class AbstractComponents(torch.nn.Module):
+    """Abstract class for all components modules."""
+    @property
+    def num_components(self):
+        """Current number of components."""
+        return len(self._components)
+
+    @property
+    def components(self):
+        """Detached Tensor containing the components."""
+        return self._components.detach()
+
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
+    def extra_repr(self):
+        return f"(components): (shape: {tuple(self._components.shape)})"
+
+
+class Components(AbstractComponents):
+    """A set of adaptable Tensors."""
+    def __init__(self, num_components: int,
+                 initializer: AbstractComponentsInitializer, **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(num_components, initializer)
+
+    def add_components(self, num: int,
+                       initializer: AbstractComponentsInitializer):
+        """Add new components."""
+        assert validate_components_initializer(initializer)
+        new_components = initializer.generate(num)
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        self._register_components(_components)
+        return new_components
+
+    def remove_components(self, indices):
+        """Remove components at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        self._register_components(_components)
+        return mask
+
+    def forward(self):
+        """Simply return the components parameter Tensor."""
+        return self._components
+
+
+class LabeledComponents(AbstractComponents):
+    """A set of adaptable components and corresponding unadaptable labels."""
+    def __init__(self, distribution: Union[dict, list, tuple],
+                 components_initializer: AbstractComponentsInitializer,
+                 labels_initializer: AbstractLabelsInitializer, **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(distribution, components_initializer,
+                            labels_initializer)
+
+    @property
+    def component_labels(self):
+        """Tensor containing the component labels."""
+        return self._labels.detach()
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def add_components(
+            self,
+            distribution,
+            components_initializer,
+            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
+        # Checks
+        assert validate_components_initializer(components_initializer)
+        assert validate_labels_initializer(labels_initializer)
+
+        distribution = parse_distribution(distribution)
+
+        # Generate new components
+        if isinstance(components_initializer, ClassAwareCompInitializer):
+            new_components = components_initializer.generate(distribution)
+        else:
+            num_components = sum(distribution.values())
+            new_components = components_initializer.generate(num_components)
+
+        # Generate new labels
+        new_labels = labels_initializer.generate(distribution)
+
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        if hasattr(self, "_labels"):
+            _labels = torch.cat([self._labels, new_labels])
+        else:
+            _labels = new_labels
+        self._register_components(_components)
+        self._register_labels(_labels)
+
+        return new_components, new_labels
+
+    def remove_components(self, indices):
+        """Remove components and labels at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        _labels = self._labels[mask]
+        self._register_components(_components)
+        self._register_labels(_labels)
+        return mask
+
+    def forward(self):
+        """Simply return the components parameter Tensor and labels."""
+        return self._components, self._labels
+
+
+class ReasoningComponents(AbstractComponents):
+    """A set of components and corresponding adaptable reasoning matrices.
+
+    Every component has its own reasoning matrix.
+
+    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
+    first element is called positive reasoning :math:`p`, the second negative
+    reasoning :math:`n`. A component can reason in favour (positive) of a
+    class, against (negative) a class or not at all (neutral).
+
+    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
+    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
+    three element probability distribution.
+
+    """
+    def __init__(self, distribution: Union[dict, list, tuple],
+                 components_initializer: AbstractComponentsInitializer,
+                 reasonings_initializer: AbstractReasoningsInitializer,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(distribution, components_initializer,
+                            reasonings_initializer)
+
+    @property
+    def reasonings(self):
+        """Returns Reasoning Matrix.
+
+        Dimension NxCx2
+
+        """
+        return self._reasonings.detach()
+
+    def _register_reasonings(self, reasonings):
+        self.register_parameter("_reasonings", Parameter(reasonings))
+
+    def add_components(self, distribution, components_initializer,
+                       reasonings_initializer: AbstractReasoningsInitializer):
+        # Checks
+        assert validate_components_initializer(components_initializer)
+        assert validate_reasonings_initializer(reasonings_initializer)
+
+        distribution = parse_distribution(distribution)
+
+        # Generate new components
+        if isinstance(components_initializer, ClassAwareCompInitializer):
+            new_components = components_initializer.generate(distribution)
+        else:
+            num_components = sum(distribution.values())
+            new_components = components_initializer.generate(num_components)
+
+        # Generate new reasonings
+        new_reasonings = reasonings_initializer.generate(distribution)
+
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        if hasattr(self, "_reasonings"):
+            _reasonings = torch.cat([self._reasonings, new_reasonings])
+        else:
+            _reasonings = new_reasonings
+        self._register_components(_components)
+        self._register_reasonings(_reasonings)
+
+        return new_components, new_reasonings
+
+    def remove_components(self, indices):
+        """Remove components and reasonings at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        # TODO
+        # _reasonings = self._reasonings[mask]
+        self._register_components(_components)
+        # self._register_reasonings(_reasonings)
+        return mask
+
+    def forward(self):
+        """Simply return the components and reasonings."""
+        return self._components, self._reasonings
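With the initializers from the companion module below, the new components API can be exercised end to end; a rough sketch (names as introduced in this patch, and assuming every class actually occurs in `targets`):

    import torch
    from prototorch.core.components import LabeledComponents
    from prototorch.core.initializers import (
        LabelsInitializer,
        StratifiedMeanCompInitializer,
    )

    data = torch.randn(100, 2)
    targets = torch.randint(3, (100, ))

    lc = LabeledComponents(
        distribution=(3, 2),  # three classes, two prototypes per class
        components_initializer=StratifiedMeanCompInitializer((data, targets)),
        labels_initializer=LabelsInitializer(),
    )
    protos, plabels = lc()  # shapes: (6, 2) and (6, )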
diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py
new file mode 100644
index 0000000..ba48ffd
--- /dev/null
+++ b/prototorch/core/initializers.py
@@ -0,0 +1,258 @@
+"""ProtoTorch core initializers"""
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Union
+
+import torch
+
+from ..utils import parse_data_arg, parse_distribution
+
+
+# Components
+class AbstractComponentsInitializer(ABC):
+    """Abstract class for all components initializers."""
+    ...
+
+
+class ShapeAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all shape-aware components initializers."""
+    def __init__(self, shape: Union[Iterable, int]):
+        if isinstance(shape, Iterable):
+            self.component_shape = tuple(shape)
+        else:
+            self.component_shape = (shape, )
+
+    @abstractmethod
+    def generate(self, num_components: int):
+        ...
+
+
+class DataAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all data-aware components initializers.
+
+    Components generated by data-aware components initializers inherit the
+    shape of the provided data.
+
+    """
+    def __init__(self,
+                 data,
+                 noise: float = 0.0,
+                 transform: callable = torch.nn.Identity()):
+        self.data = data
+        self.noise = noise
+        self.transform = transform
+
+    def generate_end_hook(self, samples):
+        drift = torch.rand_like(samples) * self.noise
+        components = self.transform(samples + drift)
+        return components
+
+    @abstractmethod
+    def generate(self, num_components: int):
+        # Subclasses are expected to produce `samples` and finish with:
+        # `return self.generate_end_hook(samples)`
+        ...
+
+    def __del__(self):
+        del self.data
+
+
+class ClassAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all class-aware components initializers.
+
+    Components generated by class-aware components initializers inherit the
+    shape of the provided data.
+
+    """
+    def __init__(self,
+                 data,
+                 noise: float = 0.0,
+                 transform: callable = torch.nn.Identity()):
+        self.data, self.targets = parse_data_arg(data)
+        self.noise = noise
+        self.transform = transform
+        self.clabels = torch.unique(self.targets).int().tolist()
+        self.num_classes = len(self.clabels)
+
+    @property
+    @abstractmethod
+    def subinit_type(self) -> DataAwareCompInitializer:
+        ...
+
+    def generate(self, distribution: Union[dict, list, tuple]):
+        distribution = parse_distribution(distribution)
+        components = torch.tensor([])
+        for k, v in distribution.items():
+            stratified_data = self.data[self.targets == k]
+            # noise and transform are applied by the data-aware subinitializer
+            initializer = self.subinit_type(
+                stratified_data,
+                noise=self.noise,
+                transform=self.transform,
+            )
+            samples = initializer.generate(num_components=v)
+            components = torch.cat([components, samples])
+        return components
+
+    def __del__(self):
+        del self.data
+        del self.targets
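The split between `generate` and `generate_end_hook` keeps noise and transform handling in one place; a minimal illustrative subclass (the class name is hypothetical, not part of the patch):

    import torch
    from prototorch.core.initializers import DataAwareCompInitializer

    class FirstKCompInitializer(DataAwareCompInitializer):
        """Take the first k samples of the data as components."""
        def generate(self, num_components: int):
            samples = self.data[:num_components]
            # noise and transform are applied by the shared end hook
            return self.generate_end_hook(samples)

    init = FirstKCompInitializer(torch.randn(10, 3), noise=0.01)
    components = init.generate(2)  # shape: (2, 3)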
+ + """ + def generate(self, num_components: int): + """Ignore `num_components` and simply return transformed `self.data`.""" + components = self.transform(self.data) + return components + + +class ZerosCompInitializer(ShapeAwareCompInitializer): + """Generate zeros corresponding to the components shape.""" + def generate(self, num_components: int): + components = torch.zeros((num_components, ) + self.component_shape) + return components + + +class OnesCompInitializer(ShapeAwareCompInitializer): + """Generate ones corresponding to the components shape.""" + def generate(self, num_components: int): + components = torch.ones((num_components, ) + self.component_shape) + return components + + +class FillValueCompInitializer(OnesCompInitializer): + """Generate components with the provided `fill_value`.""" + def __init__(self, shape, fill_value: float = 1.0): + super().__init__(shape) + self.fill_value = fill_value + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = ones.fill_(self.fill_value) + return components + + +class UniformCompInitializer(OnesCompInitializer): + """Generate components by sampling from a continuous uniform distribution.""" + def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0): + super().__init__(shape) + self.minimum = minimum + self.maximum = maximum + self.scale = scale + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = self.scale * ones.uniform_(self.minimum, self.maximum) + return components + + +class RandomNormalCompInitializer(OnesCompInitializer): + """Generate components by sampling from a standard normal distribution.""" + def __init__(self, shape, scale=1.0): + super().__init__(shape) + self.scale = scale + + def generate(self, num_components: int): + ones = super().generate(num_components) + components = self.scale * torch.randn_like(ones) + return components + + +class SelectionCompInitializer(DataAwareCompInitializer): + """Generate components by uniformly sampling from the provided data.""" + def generate(self, num_components: int): + indices = torch.LongTensor(num_components).random_(0, len(self.data)) + samples = self.data[indices] + components = self.generate_end_hook(samples) + return components + + +class MeanCompInitializer(DataAwareCompInitializer): + """Generate components by computing the mean of the provided data.""" + def generate(self, num_components: int): + mean = torch.mean(self.data, dim=0) + repeat_dim = [num_components] + [1] * len(mean.shape) + samples = mean.repeat(repeat_dim) + components = self.generate_end_hook(samples) + return components + + +class StratifiedSelectionCompInitializer(ClassAwareCompInitializer): + """Generate components using stratified sampling from the provided data.""" + @property + def subinit_type(self): + return SelectionCompInitializer + + +class StratifiedMeanCompInitializer(ClassAwareCompInitializer): + """Generate components at stratified means of the provided data.""" + @property + def subinit_type(self): + return MeanCompInitializer + + +# Labels +class AbstractLabelsInitializer(ABC): + """Abstract class for all labels initializers.""" + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... 
+ + +class LabelsInitializer(AbstractLabelsInitializer): + """Generate labels according to the given distribution.""" + def __init__(self, override_labels: list = []): + self.override_labels = override_labels + + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + labels = [] + for k, v in distribution.items(): + labels.extend([k] * v) + labels = torch.LongTensor(labels) + return labels + + +# Reasonings +class AbstractReasoningsInitializer(ABC): + """Abstract class for all reasonings initializers.""" + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + + +class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): + """Generate reasonings that are purely positive for exactly one class each.""" + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + num_classes = len(distribution.keys()) + num_components = sum(distribution.values()) + assert num_classes == num_components + reasonings = torch.stack( + [torch.eye(num_classes), + torch.zeros(num_classes, num_classes)], + dim=0) + return reasonings + + +# Aliases - Components +ZCI = ZerosCompInitializer +OCI = OnesCompInitializer +FVCI = FillValueCompInitializer +LCI = LiteralCompInitializer +UCI = UniformCompInitializer +RNCI = RandomNormalCompInitializer +SCI = SelectionCompInitializer +MCI = MeanCompInitializer +SSCI = StratifiedSelectionCompInitializer +SMCI = StratifiedMeanCompInitializer + +# Aliases - Reasonings +PPRI = PurePositiveReasoningsInitializer
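A short sketch of the labels and reasonings initializers in use (import path assumed per this patch; the PPRI output shape follows the stacked-eye construction above):

import torch
from prototorch.core.initializers import LabelsInitializer, PPRI

labels = LabelsInitializer().generate(distribution={0: 2, 1: 3})
# labels -> tensor([0, 0, 1, 1, 1])
reasonings = PPRI().generate(distribution=(3, 1))
# reasonings -> a (2, 3, 3) stack: an identity (positive part) over zeros (negative part)

diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py index 316d5eb..b2058cd 100644 --- a/prototorch/utils/utils.py +++ b/prototorch/utils/utils.py @@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100): return mesh, xx, yy -def parse_distribution(user_distribution: Union[dict, list, tuple]): +def parse_distribution( + user_distribution: Union[dict[int, int], dict[str, int], list[int], + tuple[int, int]] +) -> dict[int, int]: + """Parse user-provided distribution.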
Return a dictionary with integer keys that represent the class labels and @@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]): if isinstance(user_distribution, dict): if "num_classes" in user_distribution.keys(): - num_classes = user_distribution["num_classes"] - per_class = user_distribution["per_class"] + num_classes = int(user_distribution["num_classes"]) + per_class = int(user_distribution["per_class"]) return from_list([per_class] * num_classes) else: return user_distribution elif isinstance(user_distribution, tuple): assert len(user_distribution) == 2 num_classes, per_class = user_distribution + num_classes, per_class = int(num_classes), int(per_class) return from_list([per_class] * num_classes) elif isinstance(user_distribution, list): return from_list(user_distribution) From 25dbde4e43df62a658b8eec383dfe5b51b6130ba Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:30:53 +0200 Subject: [PATCH 13/43] Remove tests/test_functions.py --- tests/test_functions.py | 580 ---------------------------------------- 1 file changed, 580 deletions(-) delete mode 100644 tests/test_functions.py diff --git a/tests/test_functions.py b/tests/test_functions.py deleted file mode 100644 index 91fd8a7..0000000 --- a/tests/test_functions.py +++ /dev/null @@ -1,580 +0,0 @@ -"""ProtoTorch functions test suite.""" - -import unittest - -import numpy as np -import torch -from prototorch.functions import (activations, competitions, distances, - initializers, losses, pooling) - - -class TestActivations(unittest.TestCase): - def setUp(self): - self.flist = ["identity", "sigmoid_beta", "swish_beta"] - self.x = torch.randn(1024, 1) - - def test_registry(self): - self.assertIsNotNone(activations.ACTIVATIONS) - - def test_funcname_deserialization(self): - for funcname in self.flist: - f = activations.get_activation(funcname) - iscallable = callable(f) - self.assertTrue(iscallable) - - # def test_torch_script(self): - # for funcname in self.flist: - # f = activations.get_activation(funcname) - # self.assertIsInstance(f, torch.jit.ScriptFunction) - - def test_callable_deserialization(self): - def dummy(x, **kwargs): - return x - - for f in [dummy, lambda x: x]: - f = activations.get_activation(f) - iscallable = callable(f) - self.assertTrue(iscallable) - self.assertEqual(1, f(1)) - - def test_unknown_deserialization(self): - for funcname in ["blubb", "foobar"]: - with self.assertRaises(NameError): - _ = activations.get_activation(funcname) - - def test_identity(self): - actual = activations.identity(self.x) - desired = self.x - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_sigmoid_beta1(self): - actual = activations.sigmoid_beta(self.x, beta=1.0) - desired = torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_swish_beta1(self): - actual = activations.swish_beta(self.x, beta=1.0) - desired = self.x * torch.sigmoid(self.x) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x - - -class TestCompetitions(unittest.TestCase): - def setUp(self): - pass - - def test_wtac(self): - d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) - labels = torch.tensor([0, 1, 2, 3]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([2, 0]) - mismatch = 
np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_wtac_unequal_dist(self): - d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]]) - labels = torch.tensor([0, 1, 1]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([0, 1]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_wtac_one_hot(self): - d = torch.tensor([[1.99, 3.01], [3.0, 2.01]]) - labels = torch.tensor([[0, 1], [1, 0]]) - actual = competitions.wtac(d, labels) - desired = torch.tensor([[0, 1], [1, 0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_knnc_k1(self): - d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) - labels = torch.tensor([0, 1, 2, 3]) - actual = competitions.knnc(d, labels, k=1) - desired = torch.tensor([2, 0]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - pass - - -class TestPooling(unittest.TestCase): - def setUp(self): - pass - - def test_stratified_min(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.tensor([0, 0, 1, 2]) - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_min_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.tensor([0, 0, 1, 2]) - labels = torch.eye(3)[labels] - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_min_trivial(self): - d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) - labels = torch.tensor([0, 1, 2]) - actual = pooling.stratified_min_pooling(d, labels) - desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_max(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pooling.stratified_max_pooling(d, labels) - desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_max_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 2, 1, 0]) - labels = torch.nn.functional.one_hot(labels, num_classes=3) - actual = pooling.stratified_max_pooling(d, labels) - desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_sum(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) - labels = torch.LongTensor([0, 0, 1, 2]) - actual = pooling.stratified_sum_pooling(d, labels) - desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_sum_one_hot(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) 
- labels = torch.tensor([0, 0, 1, 2]) - labels = torch.eye(3)[labels] - actual = pooling.stratified_sum_pooling(d, labels) - desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_prod(self): - d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) - labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pooling.stratified_prod_pooling(d, labels) - desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - pass - - -class TestDistances(unittest.TestCase): - def setUp(self): - self.nx, self.mx = 32, 2048 - self.ny, self.my = 8, 2048 - self.x = torch.randn(self.nx, self.mx) - self.y = torch.randn(self.ny, self.my) - - def test_manhattan(self): - actual = distances.lpnorm_distance(self.x, self.y, p=1) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=1, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_euclidean(self): - actual = distances.euclidean_distance(self.x, self.y) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=3) - self.assertIsNone(mismatch) - - def test_squared_euclidean(self): - actual = distances.squared_euclidean_distance(self.x, self.y) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_lpnorm_p0(self): - actual = distances.lpnorm_distance(self.x, self.y, p=0) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=0, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_lpnorm_p2(self): - actual = distances.lpnorm_distance(self.x, self.y, p=2) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_lpnorm_p3(self): - actual = distances.lpnorm_distance(self.x, self.y, p=3) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=3, - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - 
def test_lpnorm_pinf(self): - actual = distances.lpnorm_distance(self.x, self.y, p=float("inf")) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=float("inf"), - keepdim=False, - ) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=4) - self.assertIsNone(mismatch) - - def test_omega_identity(self): - omega = torch.eye(self.mx, self.my) - actual = distances.omega_distance(self.x, self.y, omega=omega) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def test_lomega_identity(self): - omega = torch.eye(self.mx, self.my) - omegas = torch.stack([omega for _ in range(self.ny)], dim=0) - actual = distances.lomega_distance(self.x, self.y, omegas=omegas) - desired = torch.empty(self.nx, self.ny) - for i in range(self.nx): - for j in range(self.ny): - desired[i][j] = (torch.nn.functional.pairwise_distance( - self.x[i].reshape(1, -1), - self.y[j].reshape(1, -1), - p=2, - keepdim=False, - )**2) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=2) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x, self.y - - -class TestInitializers(unittest.TestCase): - def setUp(self): - self.flist = [ - "zeros", - "ones", - "rand", - "randn", - "stratified_mean", - "stratified_random", - ] - self.x = torch.tensor( - [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], - dtype=torch.float32) - self.y = torch.tensor([0, 0, 1, 1]) - self.gen = torch.manual_seed(42) - - def test_registry(self): - self.assertIsNotNone(initializers.INITIALIZERS) - - def test_funcname_deserialization(self): - for funcname in self.flist: - f = initializers.get_initializer(funcname) - iscallable = callable(f) - self.assertTrue(iscallable) - - def test_callable_deserialization(self): - def dummy(x): - return x - - for f in [dummy, lambda x: x]: - f = initializers.get_initializer(f) - iscallable = callable(f) - self.assertTrue(iscallable) - self.assertEqual(1, f(1)) - - def test_unknown_deserialization(self): - for funcname in ["blubb", "foobar"]: - with self.assertRaises(NameError): - _ = initializers.get_initializer(funcname) - - def test_zeros(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.zeros(self.x, self.y, pdist) - desired = torch.zeros(2, 3) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_ones(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.ones(self.x, self.y, pdist) - desired = torch.ones(2, 3) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_rand(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.rand(self.x, self.y, pdist) - desired = torch.rand(2, 3, generator=torch.manual_seed(42)) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_randn(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.randn(self.x, self.y, pdist) - desired = torch.randn(2, 3, generator=torch.manual_seed(42)) - mismatch = 
np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_equal1(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_equal1(self): - pdist = torch.tensor([1, 1]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_equal2(self): - pdist = torch.tensor([2, 2]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_equal2(self): - pdist = torch.tensor([2, 2]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_unequal(self): - pdist = torch.tensor([1, 3]) - actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) - desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_unequal(self): - pdist = torch.tensor([1, 3]) - actual, _ = initializers.stratified_random(self.x, self.y, pdist, - False) - desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - mismatch = np.testing.assert_array_almost_equal(actual, - desired, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_mean_unequal_one_hot(self): - pdist = torch.tensor([1, 3]) - y = torch.eye(2)[self.y] - desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - actual1, actual2 = initializers.stratified_mean(self.x, y, pdist) - desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) - mismatch = np.testing.assert_array_almost_equal(actual1, - desired1, - decimal=5) - mismatch = np.testing.assert_array_almost_equal(actual2, - desired2, - decimal=5) - self.assertIsNone(mismatch) - - def test_stratified_random_unequal_one_hot(self): - pdist = torch.tensor([1, 3]) - y = torch.eye(2)[self.y] - actual1, actual2 = initializers.stratified_random(self.x, y, pdist) - desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) - mismatch = np.testing.assert_array_almost_equal(actual1, - desired1, - decimal=5) - mismatch = np.testing.assert_array_almost_equal(actual2, - desired2, - decimal=5) - self.assertIsNone(mismatch) - - def tearDown(self): - del self.x, self.y, self.gen - _ = torch.seed() - - -class TestLosses(unittest.TestCase): - def setUp(self): - pass - - def test_glvq_loss_int_labels(self): - d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) - labels = torch.tensor([0, 1]) - targets = 
torch.ones(100) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def test_glvq_loss_one_hot_labels(self): - d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) - labels = torch.tensor([[0, 1], [1, 0]]) - wl = torch.tensor([1, 0]) - targets = torch.stack([wl for _ in range(100)], dim=0) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def test_glvq_loss_one_hot_unequal(self): - dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)] - d = torch.stack(dlist, dim=1) - labels = torch.tensor([[0, 1], [1, 0], [1, 0]]) - wl = torch.tensor([1, 0]) - targets = torch.stack([wl for _ in range(100)], dim=0) - batch_loss = losses.glvq_loss(distances=d, - target_labels=targets, - prototype_labels=labels) - loss_value = torch.sum(batch_loss, dim=0) - self.assertEqual(loss_value, -100) - - def tearDown(self): - pass From 093a79d5337ed5e575bbcfe7680e3665c59e3b35 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:38:16 +0200 Subject: [PATCH 14/43] [REFACTOR] Reorganize files and folders --- prototorch/__init__.py | 28 +- prototorch/core/__init__.py | 3 +- prototorch/{modules => core}/competitions.py | 31 ++- prototorch/core/distances.py | 261 +++++++++++++++++++ prototorch/core/losses.py | 151 +++++++++++ prototorch/core/pooling.py | 104 ++++++++ prototorch/modules/losses.py | 58 ----- prototorch/modules/pooling.py | 31 --- prototorch/nn/__init__.py | 4 + prototorch/nn/activations.py | 62 +++++ prototorch/{modules => nn}/wrappers.py | 2 +- 11 files changed, 633 insertions(+), 102 deletions(-) rename prototorch/{modules => core}/competitions.py (51%) create mode 100644 prototorch/core/distances.py create mode 100644 prototorch/core/losses.py create mode 100644 prototorch/core/pooling.py delete mode 100644 prototorch/modules/losses.py delete mode 100644 prototorch/modules/pooling.py create mode 100644 prototorch/nn/__init__.py create mode 100644 prototorch/nn/activations.py rename prototorch/{modules => nn}/wrappers.py (97%) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index fe6a6e1..d549de2 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -1,20 +1,36 @@ -"""ProtoTorch package.""" +"""ProtoTorch package""" import pkgutil import pkg_resources -from . import components, datasets, functions, modules, utils -from .datasets import * +from . 
import ( + datasets, + nn, + utils, +) +from .core import ( + competitions, + components, + distances, + initializers, + losses, + pooling, +) # Core Setup __version__ = "0.5.0" __all_core__ = [ - "datasets", - "functions", - "modules", + "competitions", "components", + "core", + "datasets", + "distances", + "initializers", + "losses", + "nn", + "pooling", "utils", ] diff --git a/prototorch/core/__init__.py b/prototorch/core/__init__.py index 17be644..4badc95 100644 --- a/prototorch/core/__init__.py +++ b/prototorch/core/__init__.py @@ -1,5 +1,6 @@ """ProtoTorch core""" +from .competitions import * from .components import * from .initializers import * -from .labels import * +from .losses import * diff --git a/prototorch/modules/competitions.py b/prototorch/core/competitions.py similarity index 51% rename from prototorch/modules/competitions.py rename to prototorch/core/competitions.py index 585c5d6..2e354b6 100644 --- a/prototorch/modules/competitions.py +++ b/prototorch/core/competitions.py @@ -1,7 +1,31 @@ -"""ProtoTorch Competition Modules.""" +"""ProtoTorch competitions""" import torch -from prototorch.functions.competitions import knnc, wtac + + +def wtac(distances: torch.Tensor, + labels: torch.LongTensor) -> (torch.LongTensor): + """Winner-Takes-All-Competition. + + Returns the labels corresponding to the winners. + + """ + winning_indices = torch.min(distances, dim=1).indices + winning_labels = labels[winning_indices].squeeze() + return winning_labels + + +def knnc(distances: torch.Tensor, + labels: torch.LongTensor, + k: int = 1) -> (torch.LongTensor): + """K-Nearest-Neighbors-Competition. + + Returns the labels corresponding to the winners. + + """ + winning_indices = torch.topk(-distances, k=k, dim=1).indices + winning_labels = torch.mode(labels[winning_indices], dim=1).values + return winning_labels class WTAC(torch.nn.Module): @@ -10,7 +34,6 @@ class WTAC(torch.nn.Module): Thin wrapper over the `wtac` function. """ - def forward(self, distances, labels): return wtac(distances, labels) @@ -21,7 +44,6 @@ class LTAC(torch.nn.Module): Thin wrapper over the `wtac` function. """ - def forward(self, probs, labels): return wtac(-1.0 * probs, labels) @@ -32,7 +54,6 @@ class KNNC(torch.nn.Module): Thin wrapper over the `knnc` function. """ - def __init__(self, k=1, **kwargs): super().__init__(**kwargs) self.k = k
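A small usage sketch of the `wtac` competition defined above (mirroring a case from the removed test suite; import path as introduced by this patch):

import torch
from prototorch.core.competitions import wtac

d = torch.tensor([[2.0, 3.0, 1.99, 3.01],
                  [2.0, 3.0, 2.01, 3.0]])
plabels = torch.tensor([0, 1, 2, 3])
winners = wtac(d, plabels)
# winners -> tensor([2, 0]): each sample takes the label of its closest prototype

diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py new file mode 100644 index 0000000..0782769 --- /dev/null +++ b/prototorch/core/distances.py @@ -0,0 +1,261 @@ +"""ProtoTorch distances""" + +import numpy as np +import torch + +# from prototorch.functions.helper import ( +# _check_shapes, +# _int_and_mixed_shape, +# equal_int_shape, +# get_flat, +# ) + + +def squared_euclidean_distance(x, y): + r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. + + Compute :math:`{\| \bm x - \bm y \|}^2_2` + + **Alias:** + ``prototorch.core.distances.sed`` + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + expanded_x = x.unsqueeze(dim=1) + batchwise_difference = y - expanded_x + differences_raised = torch.pow(batchwise_difference, 2) + distances = torch.sum(differences_raised, axis=2) + return distances + + +def euclidean_distance(x, y): + r"""Compute the Euclidean distance between :math:`\bm x` and :math:`\bm y`.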
+ + Compute :math:`{\| \bm x - \bm y \|}_2` + + :returns: Distance Tensor of shape :math:`X \times Y` + :rtype: `torch.tensor` + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + distances_raised = squared_euclidean_distance(x, y) + distances = torch.sqrt(distances_raised) + return distances + + +def euclidean_distance_v2(x, y): + """Alternative Euclidean distance implementation via batched Gram matrices.""" + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + diff = y - x.unsqueeze(1) + pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() + # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the + # batch diagonal. See: + # https://pytorch.org/docs/stable/generated/torch.diagonal.html + distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1) + # print(f"{diff.shape=}") # (nx, ny, ndim) + # print(f"{pairwise_distances.shape=}") # (nx, ny, ny) + # print(f"{distances.shape=}") # (nx, ny) + return distances + + +def lpnorm_distance(x, y, p): + r"""Calculate the lp-norm distance between :math:`\bm x` and :math:`\bm y`. + Also known as the Minkowski distance. + + Compute :math:`{\| \bm x - \bm y \|}_p`. + + Calls ``torch.cdist``. + + :param p: p parameter of the lp norm + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + distances = torch.cdist(x, y, p=p) + return distances + + +def omega_distance(x, y, omega): + r"""Omega distance. + + Compute the squared Euclidean distance in the projected space, + :math:`{\| \bm x \Omega - \bm y \Omega \|}^2_2` + + :param `torch.tensor` omega: Two dimensional matrix + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + projected_x = x @ omega + projected_y = y @ omega + distances = squared_euclidean_distance(projected_x, projected_y) + return distances + + +def lomega_distance(x, y, omegas): + r"""Localized Omega distance. + + Compute the squared Euclidean distance under one projection per prototype, + :math:`{\| \bm x \Omega_k - \bm y_k \Omega_k \|}^2_2` + + :param `torch.tensor` omegas: Three dimensional matrix + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + projected_x = x @ omegas + projected_y = torch.diagonal(y @ omegas).T + expanded_y = torch.unsqueeze(projected_y, dim=1) + batchwise_difference = expanded_y - projected_x + differences_squared = batchwise_difference**2 + distances = torch.sum(differences_squared, dim=2) + distances = distances.permute(1, 0) + return distances
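A quick sanity-check sketch for `omega_distance` (an identity omega reduces it to the plain squared Euclidean distance, as the removed test_omega_identity also asserted):

import torch
from prototorch.core.distances import omega_distance, squared_euclidean_distance

x, y = torch.randn(4, 3), torch.randn(2, 3)
d_omega = omega_distance(x, y, omega=torch.eye(3))
d_sed = squared_euclidean_distance(x, y)
assert torch.allclose(d_omega, d_sed)  # identity projection changes nothing

+ + +# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10): +# r"""Computes an euclidean distances matrix given two distinct vectors. +# last dimension must be the vector dimension! +# compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction! + +# - ``x.shape = (number_of_x_vectors, vector_dim)`` +# - ``y.shape = (number_of_y_vectors, vector_dim)`` + +# output: matrix of distances (number_of_x_vectors, number_of_y_vectors) +# """ +# for tensor in [x, y]: +# if tensor.ndim != 2: +# raise ValueError( +# "The tensor dimension must be two. You provide: tensor.ndim=" + +# str(tensor.ndim) + ".") +# if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): +# raise ValueError( +# "The vector shape must be equivalent in both tensors.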
You provide: tuple(y.shape)[1]=" +# + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" + +# str(tuple(y.shape)[1]) + ".") + +# y = torch.transpose(y) + +# diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) + +# torch.sum(y**2, axis=0, keepdims=True)) + +# if not squared: +# if epsilon == 0: +# diss = torch.sqrt(diss) +# else: +# diss = torch.sqrt(torch.max(diss, epsilon)) + +# return diss + +# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): +# r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews + +# For more info about Tangen distances see + +# DOI:10.1109/IJCNN.2016.7727534. + +# The subspaces is always assumed as transposed and must be orthogonal! +# For local non sparse signals subspaces must be provided! + +# - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN +# - shape(protos): proto_number x dim1 x dim2 x ... x dimN +# - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape) + +# subspace should be orthogonalized +# Pytorch implementation of Sascha Saralajew's tensorflow code. +# Translation by Christoph Raab +# """ +# signal_shape, signal_int_shape = _int_and_mixed_shape(signals) +# proto_shape, proto_int_shape = _int_and_mixed_shape(protos) +# subspace_int_shape = tuple(subspaces.shape) + +# # check if the shapes are correct +# _check_shapes(signal_int_shape, proto_int_shape) + +# atom_axes = list(range(3, len(signal_int_shape))) +# # for sparse signals, we use the memory efficient implementation +# if signal_int_shape[1] == 1: +# signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])]) + +# if len(atom_axes) > 1: +# protos = torch.reshape(protos, [proto_shape[0], -1]) + +# if subspaces.ndim == 2: +# # clean solution without map if the matrix_scope is global +# projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( +# subspaces, torch.transpose(subspaces)) + +# projected_signals = torch.dot(signals, projectors) +# projected_protos = torch.dot(protos, projectors) + +# diss = euclidean_distance_matrix(projected_signals, +# projected_protos, +# squared=squared, +# epsilon=epsilon) + +# diss = torch.reshape( +# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) + +# return torch.permute(diss, [0, 2, 1]) + +# else: + +# # no solution without map possible --> memory efficient but slow! 
+# projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( +# subspaces, +# subspaces) # K.batch_dot(subspaces, subspaces, [2, 2]) + +# projected_protos = (protos @ subspaces +# ).T # K.batch_dot(projectors, protos, [1, 1])) + +# def projected_norm(projector): +# return torch.sum(torch.dot(signals, projector)**2, axis=1) + +# diss = (torch.transpose(map(projected_norm, projectors)) - +# 2 * torch.dot(signals, projected_protos) + +# torch.sum(projected_protos**2, axis=0, keepdims=True)) + +# if not squared: +# if epsilon == 0: +# diss = torch.sqrt(diss) +# else: +# diss = torch.sqrt(torch.max(diss, epsilon)) + +# diss = torch.reshape( +# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) + +# return torch.permute(diss, [0, 2, 1]) + +# else: +# signals = signals.permute([0, 2, 1] + atom_axes) + +# diff = signals - protos + +# # global tangent space +# if subspaces.ndim == 2: +# # Scope Projectors +# projectors = subspaces # + +# # Scope: Tangentspace Projections +# diff = torch.reshape( +# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) +# projected_diff = diff @ projectors +# projected_diff = torch.reshape( +# projected_diff, +# (signal_shape[0], signal_shape[2], signal_shape[1]) + +# signal_shape[3:], +# ) + +# diss = torch.norm(projected_diff, 2, dim=-1) +# return diss.permute([0, 2, 1]) + +# # local tangent spaces +# else: +# # Scope: Calculate Projectors +# projectors = subspaces + +# # Scope: Tangentspace Projections +# diff = torch.reshape( +# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) +# diff = diff.permute([1, 0, 2]) +# projected_diff = torch.bmm(diff, projectors) +# projected_diff = torch.reshape( +# projected_diff, +# (signal_shape[1], signal_shape[0], signal_shape[2]) + +# signal_shape[3:], +# ) + +# diss = torch.norm(projected_diff, 2, dim=-1) +# return diss.permute([1, 0, 2]).squeeze(-1) + +# Aliases +sed = squared_euclidean_distance diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py new file mode 100644 index 0000000..ab3705f --- /dev/null +++ b/prototorch/core/losses.py @@ -0,0 +1,151 @@ +"""ProtoTorch losses""" + +import torch + +from ..nn.activations import get_activation + + +# Helpers +def _get_matcher(targets, labels): + """Returns a boolean tensor.""" + matcher = torch.eq(targets.unsqueeze(dim=1), labels) + if labels.ndim == 2: + # if the labels are one-hot vectors + num_classes = targets.size()[1] + matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) + return matcher + + +def _get_dp_dm(distances, targets, plabels, with_indices=False): + """Returns the d+ and d- values for a batch of distances.""" + matcher = _get_matcher(targets, plabels) + not_matcher = torch.bitwise_not(matcher) + + inf = torch.full_like(distances, fill_value=float("inf")) + d_matching = torch.where(matcher, distances, inf) + d_unmatching = torch.where(not_matcher, distances, inf) + dp = torch.min(d_matching, dim=-1, keepdim=True) + dm = torch.min(d_unmatching, dim=-1, keepdim=True) + if with_indices: + return dp, dm + return dp.values, dm.values + + +# GLVQ +def glvq_loss(distances, target_labels, prototype_labels): + """GLVQ loss function with support for one-hot labels.""" + dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + mu = (dp - dm) / (dp + dm) + return mu + + +def lvq1_loss(distances, target_labels, prototype_labels): + """LVQ1 loss function with support for one-hot labels. 
+ + See Section 4 [Sato & Yamada] + https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf + """ + dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + mu = dp + mu[dp > dm] = -dm[dp > dm] + return mu + + +def lvq21_loss(distances, target_labels, prototype_labels): + """LVQ2.1 loss function with support for one-hot labels. + + See Section 4 [Sato & Yamada] + https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf + """ + dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + mu = dp - dm + + return mu + + +# Probabilistic +def _get_class_probabilities(probabilities, targets, prototype_labels): + # Create Label Mapping + uniques = prototype_labels.unique(sorted=True).tolist() + key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} + + target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist()))) + + whole = probabilities.sum(dim=1) + correct = probabilities[torch.arange(len(probabilities)), target_indices] + wrong = whole - correct + + return whole, correct, wrong + + +def nllr_loss(probabilities, targets, prototype_labels): + """Compute the Negative Log-Likelihood Ratio loss.""" + _, correct, wrong = _get_class_probabilities(probabilities, targets, + prototype_labels) + + likelihood = correct / wrong + log_likelihood = torch.log(likelihood) + return -1.0 * log_likelihood + + +def rslvq_loss(probabilities, targets, prototype_labels): + """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss.""" + whole, correct, _ = _get_class_probabilities(probabilities, targets, + prototype_labels) + + likelihood = correct / whole + log_likelihood = torch.log(likelihood) + return -1.0 * log_likelihood + + +class GLVQLoss(torch.nn.Module): + def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): + super().__init__(**kwargs) + self.margin = margin + self.squashing = get_activation(squashing) + self.beta = torch.tensor(beta) + + def forward(self, outputs, targets): + distances, plabels = outputs + mu = glvq_loss(distances, targets, prototype_labels=plabels) + batch_loss = self.squashing(mu + self.margin, beta=self.beta) + return torch.sum(batch_loss, dim=0) + + +class NeuralGasEnergy(torch.nn.Module): + def __init__(self, lm, **kwargs): + super().__init__(**kwargs) + self.lm = lm + + def forward(self, d): + order = torch.argsort(d, dim=1) + ranks = torch.argsort(order, dim=1) + cost = torch.sum(self._nghood_fn(ranks, self.lm) * d) + + return cost, order + + def extra_repr(self): + return f"lambda: {self.lm}" + + @staticmethod + def _nghood_fn(rankings, lm): + return torch.exp(-rankings / lm) + + +class GrowingNeuralGasEnergy(NeuralGasEnergy): + def __init__(self, topology_layer, **kwargs): + super().__init__(**kwargs) + self.topology_layer = topology_layer + + @staticmethod + def _nghood_fn(rankings, topology): + winner = rankings[:, 0] + + weights = torch.zeros_like(rankings, dtype=torch.float) + weights[torch.arange(rankings.shape[0]), winner] = 1.0 + + neighbours = topology.get_neighbours(winner) + + weights[neighbours] = 0.1 + + return weights
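A numeric sketch of `glvq_loss` (mirroring a case from the removed loss tests; every sample of class 1 sits at distance 0 from its own prototype and distance 1 from the other):

import torch
from prototorch.core.losses import glvq_loss

d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
plabels = torch.tensor([0, 1])
targets = torch.ones(100)
mu = glvq_loss(d, targets, plabels)  # (d+ - d-) / (d+ + d-) = -1 per sample
assert torch.sum(mu) == -100

diff --git a/prototorch/core/pooling.py b/prototorch/core/pooling.py new file mode 100644 index 0000000..fab143f --- /dev/null +++ b/prototorch/core/pooling.py @@ -0,0 +1,104 @@ +"""ProtoTorch pooling""" + +from typing import Callable + +import torch + + +def stratify_with(values: torch.Tensor, + labels: torch.LongTensor, + fn: Callable, + fill_value: float = 0.0) -> (torch.Tensor): + """Apply an arbitrary stratification strategy on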
the columns of `values`. + + The outputs correspond to sorted labels. + """ + clabels = torch.unique(labels, dim=0, sorted=True) + num_classes = clabels.size()[0] + if values.size()[1] == num_classes: + # skip if stratification is trivial + return values + batch_size = values.size()[0] + winning_values = torch.zeros(num_classes, batch_size, device=labels.device) + filler = torch.full_like(values.T, fill_value=fill_value) + for i, cl in enumerate(clabels): + matcher = torch.eq(labels.unsqueeze(dim=1), cl) + if labels.ndim == 2: + # if the labels are one-hot vectors + matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) + cdists = torch.where(matcher, values.T, filler).T + winning_values[i] = fn(cdists) + if labels.ndim == 2: + # Transpose to return with `batch_size` first and + # reverse the columns to fix the ordering of the classes + return torch.flip(winning_values.T, dims=(1, )) + + return winning_values.T # return with `batch_size` first + + +def stratified_sum_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise sum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(), + fill_value=0.0) + return winning_values + + +def stratified_min_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise minimum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(), + fill_value=float("inf")) + return winning_values + + +def stratified_max_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise maximum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(), + fill_value=-1.0 * float("inf")) + return winning_values + + +def stratified_prod_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise product.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(), + fill_value=1.0) + return winning_values + + +class StratifiedSumPooling(torch.nn.Module): + """Thin wrapper over the `stratified_sum_pooling` function.""" + def forward(self, values, labels): + return stratified_sum_pooling(values, labels) + + +class StratifiedProdPooling(torch.nn.Module): + """Thin wrapper over the `stratified_prod_pooling` function.""" + def forward(self, values, labels): + return stratified_prod_pooling(values, labels) + + +class StratifiedMinPooling(torch.nn.Module): + """Thin wrapper over the `stratified_min_pooling` function.""" + def forward(self, values, labels): + return stratified_min_pooling(values, labels) + + +class StratifiedMaxPooling(torch.nn.Module): + """Thin wrapper over the `stratified_max_pooling` function.""" + def forward(self, values, labels): + return stratified_max_pooling(values, labels)
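A small sketch of the stratified pooling semantics (columns sharing a label are reduced together; values mirror the removed test_stratified_min):

import torch
from prototorch.core.pooling import stratified_min_pooling

d = torch.tensor([[1.0, 0.0, 2.0, 3.0],
                  [9.0, 8.0, 0.0, 1.0]])
labels = torch.tensor([0, 0, 1, 2])  # the first two columns share class 0
stratified_min_pooling(d, labels)
# -> tensor([[0., 2., 3.],
#            [8., 0., 1.]])

diff --git a/prototorch/modules/losses.py b/prototorch/modules/losses.py deleted file mode 100644 index 706d123..0000000 --- a/prototorch/modules/losses.py +++ /dev/null @@ -1,58 +0,0 @@ -"""ProtoTorch losses.""" - -import torch -from prototorch.functions.activations import get_activation -from prototorch.functions.losses import glvq_loss - - -class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): - super().__init__(**kwargs) - self.margin = margin - self.squashing = get_activation(squashing) - self.beta = torch.tensor(beta) - - def forward(self,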
outputs, targets): - distances, plabels = outputs - mu = glvq_loss(distances, targets, prototype_labels=plabels) - batch_loss = self.squashing(mu + self.margin, beta=self.beta) - return torch.sum(batch_loss, dim=0) - - -class NeuralGasEnergy(torch.nn.Module): - def __init__(self, lm, **kwargs): - super().__init__(**kwargs) - self.lm = lm - - def forward(self, d): - order = torch.argsort(d, dim=1) - ranks = torch.argsort(order, dim=1) - cost = torch.sum(self._nghood_fn(ranks, self.lm) * d) - - return cost, order - - def extra_repr(self): - return f"lambda: {self.lm}" - - @staticmethod - def _nghood_fn(rankings, lm): - return torch.exp(-rankings / lm) - - -class GrowingNeuralGasEnergy(NeuralGasEnergy): - def __init__(self, topology_layer, **kwargs): - super().__init__(**kwargs) - self.topology_layer = topology_layer - - @staticmethod - def _nghood_fn(rankings, topology): - winner = rankings[:, 0] - - weights = torch.zeros_like(rankings, dtype=torch.float) - weights[torch.arange(rankings.shape[0]), winner] = 1.0 - - neighbours = topology.get_neighbours(winner) - - weights[neighbours] = 0.1 - - return weights diff --git a/prototorch/modules/pooling.py b/prototorch/modules/pooling.py deleted file mode 100644 index eebf559..0000000 --- a/prototorch/modules/pooling.py +++ /dev/null @@ -1,31 +0,0 @@ -"""ProtoTorch Pooling Modules.""" - -import torch -from prototorch.functions.pooling import (stratified_max_pooling, - stratified_min_pooling, - stratified_prod_pooling, - stratified_sum_pooling) - - -class StratifiedSumPooling(torch.nn.Module): - """Thin wrapper over the `stratified_sum_pooling` function.""" - def forward(self, values, labels): - return stratified_sum_pooling(values, labels) - - -class StratifiedProdPooling(torch.nn.Module): - """Thin wrapper over the `stratified_prod_pooling` function.""" - def forward(self, values, labels): - return stratified_prod_pooling(values, labels) - - -class StratifiedMinPooling(torch.nn.Module): - """Thin wrapper over the `stratified_min_pooling` function.""" - def forward(self, values, labels): - return stratified_min_pooling(values, labels) - - -class StratifiedMaxPooling(torch.nn.Module): - """Thin wrapper over the `stratified_max_pooling` function.""" - def forward(self, values, labels): - return stratified_max_pooling(values, labels) diff --git a/prototorch/nn/__init__.py b/prototorch/nn/__init__.py new file mode 100644 index 0000000..bf2445e --- /dev/null +++ b/prototorch/nn/__init__.py @@ -0,0 +1,4 @@ +"""ProtoTorch Neural Network Module""" + +from .activations import * +from .wrappers import * diff --git a/prototorch/nn/activations.py b/prototorch/nn/activations.py new file mode 100644 index 0000000..7931e14 --- /dev/null +++ b/prototorch/nn/activations.py @@ -0,0 +1,62 @@ +"""ProtoTorch activations""" + +import torch + +ACTIVATIONS = dict() + + +def register_activation(fn): + """Add the activation function to the registry.""" + name = fn.__name__ + ACTIVATIONS[name] = fn + return fn + + +@register_activation +def identity(x, beta=0.0): + """Identity activation function. + + Definition: + :math:`f(x) = x` + + Keyword Arguments: + beta (`float`): Ignored. + """ + return x + + +@register_activation +def sigmoid_beta(x, beta=10.0): + r"""Sigmoid activation function with scaling. 
+ + Definition: + :math:`f(x) = \frac{1}{1 + e^{-\beta x}}` + + Keyword Arguments: + beta (`float`): Scaling parameter :math:`\beta` + """ + out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x)) + return out + + +@register_activation +def swish_beta(x, beta=10.0): + r"""Swish activation function with scaling. + + Definition: + :math:`f(x) = \frac{x}{1 + e^{-\beta x}}` + + Keyword Arguments: + beta (`float`): Scaling parameter :math:`\beta` + """ + out = x * sigmoid_beta(x, beta=beta) + return out + + +def get_activation(funcname): + """Deserialize the activation function.""" + if callable(funcname): + return funcname + if funcname in ACTIVATIONS: + return ACTIVATIONS.get(funcname) + raise NameError(f"Activation {funcname} was not found.") diff --git a/prototorch/modules/wrappers.py b/prototorch/nn/wrappers.py similarity index 97% rename from prototorch/modules/wrappers.py rename to prototorch/nn/wrappers.py index da94c52..c3fe781 100644 --- a/prototorch/modules/wrappers.py +++ b/prototorch/nn/wrappers.py @@ -1,4 +1,4 @@ -"""ProtoTorch Wrappers.""" +"""ProtoTorch wrappers.""" import torch From a30672b932a446397252f029debc8db189c0ae70 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:39:47 +0200 Subject: [PATCH 15/43] Temporarily remove GTLVQ --- prototorch/modules/models.py | 169 ----------------------------------- 1 file changed, 169 deletions(-) delete mode 100644 prototorch/modules/models.py diff --git a/prototorch/modules/models.py b/prototorch/modules/models.py deleted file mode 100644 index f7a2d2c..0000000 --- a/prototorch/modules/models.py +++ /dev/null @@ -1,169 +0,0 @@ -import torch -from prototorch.components import LabeledComponents, StratifiedMeanInitializer -from prototorch.functions.distances import euclidean_distance_matrix -from prototorch.functions.normalization import orthogonalization -from torch import nn - - -class GTLVQ(nn.Module): - r""" Generalized Tangent Learning Vector Quantization - - Parameters - ---------- - num_classes: int - Number of classes of the given classification problem. - - subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim) - Subspace data for the point approximation, required - - prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional) - prototype data for initalization of the prototypes used in GTLVQ. - - subspace_size: int (default=256,optional) - Subspace dimension of the Projectors. Currently only supported - with tagnent_projection_type=global. - - tangent_projection_type: string - Specifies the tangent projection type - options: local - local_proj - global - local: computes the tangent distances without emphasizing projected - data. Only distances are available - local_proj: computs tangent distances and returns the projected data - for further use. Be careful: data is repeated by number of prototypes - global: Number of subspaces is set to one and every prototypes - uses the same. - - prototypes_per_class: int (default=2,optional) - Number of prototypes per class - - feature_dim: int (default=256) - Dimensionality of the feature space specified as integer. - Prototype dimension. - - Notes - ----- - The GTLVQ [1] is a prototype-based classification learning model. The - GTLVQ uses the Tangent-Distances for a local point approximation - of an assumed data manifold via prototypial representations. - - The GTLVQ requires subspace projectors for transforming the data - and prototypes into the affine subspace. 
Every prototype is - equipped with a specific subpspace and represents a point - approximation of the assumed manifold. - - In practice prototypes and data are projected on this manifold - and pairwise euclidean distance computes. - - References - ---------- - .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning - in classification based on manifolc. models and its relation - to tangent metric learning. In: 2017 International Joint - Conference on Neural Networks (IJCNN). - Bd. 2017-May : IEEE, 2017, S. 1756–1765 - """ - def __init__( - self, - num_classes, - subspace_data=None, - prototype_data=None, - subspace_size=256, - tangent_projection_type="local", - prototypes_per_class=2, - feature_dim=256, - ): - super(GTLVQ, self).__init__() - - self.num_protos = num_classes * prototypes_per_class - self.num_protos_class = prototypes_per_class - self.subspace_size = feature_dim if subspace_size is None else subspace_size - self.feature_dim = feature_dim - self.num_classes = num_classes - - cls_initializer = StratifiedMeanInitializer(prototype_data) - cls_distribution = { - "num_classes": num_classes, - "prototypes_per_class": prototypes_per_class, - } - - self.cls = LabeledComponents(cls_distribution, cls_initializer) - - if subspace_data is None: - raise ValueError("Init Data must be specified!") - - self.tpt = tangent_projection_type - with torch.no_grad(): - if self.tpt == "local": - self.init_local_subspace(subspace_data, subspace_size, - self.num_protos) - elif self.tpt == "global": - self.init_gobal_subspace(subspace_data, subspace_size) - else: - self.subspaces = None - - def forward(self, x): - if self.tpt == "local": - dis = self.local_tangent_distances(x) - elif self.tpt == "gloabl": - dis = self.global_tangent_distances(x) - else: - dis = (x @ self.cls.prototypes.T) / ( - torch.norm(x, dim=1, keepdim=True) @ torch.norm( - self.cls.prototypes, dim=1, keepdim=True).T) - return dis - - def init_gobal_subspace(self, data, num_subspaces): - _, _, v = torch.svd(data) - subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T - subspaces = subspace[:, :num_subspaces] - self.subspaces = nn.Parameter(subspaces, requires_grad=True) - - def init_local_subspace(self, data, num_subspaces, num_protos): - data = data - torch.mean(data, dim=0) - _, _, v = torch.svd(data, some=False) - v = v[:, :num_subspaces] - subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0) - self.subspaces = nn.Parameter(subspaces, requires_grad=True) - - def global_tangent_distances(self, x): - # Tangent Projection - x, projected_prototypes = ( - x @ self.subspaces, - self.cls.prototypes @ self.subspaces, - ) - # Euclidean Distance - return euclidean_distance_matrix(x, projected_prototypes) - - def local_tangent_distances(self, x): - - # Tangent Distance - x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components, - x.size(-1)) - protos = self.cls()[0].unsqueeze(0).expand(x.size(0), - self.cls.num_components, - x.size(-1)) - projectors = torch.eye( - self.subspaces.shape[-2], device=x.device) - torch.bmm( - self.subspaces, self.subspaces.permute([0, 2, 1])) - diff = (x - protos) - diff = diff.permute([1, 0, 2]) - diff = torch.bmm(diff, projectors) - diff = torch.norm(diff, 2, dim=-1).T - return diff - - def get_parameters(self): - return { - "params": self.cls.components, - }, { - "params": self.subspaces - } - - def orthogonalize_subspace(self): - if self.subspaces is not None: - with torch.no_grad(): - ortho_subpsaces = (orthogonalization(self.subspaces) - if self.tpt == "global" else - 
torch.nn.init.orthogonal_(self.subspaces)) - self.subspaces.copy_(ortho_subpsaces) From 1ba7f5c4f76771c1df9e55648c131dc739f8e81e Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:40:23 +0200 Subject: [PATCH 16/43] Add core test suite --- tests/test_core.py | 552 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 552 insertions(+) create mode 100644 tests/test_core.py diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..d2496c8 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,552 @@ +"""ProtoTorch core test suite""" + +import unittest + +import numpy as np +import pytest +import torch + +import prototorch as pt +from prototorch.utils import parse_distribution + + +# Utils +def test_parse_distribution_dict_0(): + distribution = {"num_classes": 1, "per_class": 0} + distribution = parse_distribution(distribution) + assert distribution == {0: 0} + + +def test_parse_distribution_dict_1(): + distribution = dict(num_classes=3, per_class=2) + distribution = parse_distribution(distribution) + assert distribution == {0: 2, 1: 2, 2: 2} + + +def test_parse_distribution_dict_2(): + distribution = {0: 1, 2: 2, -1: 3} + distribution = parse_distribution(distribution) + assert distribution == {0: 1, 2: 2, -1: 3} + + +def test_parse_distribution_tuple(): + distribution = (2, 3) + distribution = parse_distribution(distribution) + assert distribution == {0: 3, 1: 3} + + +def test_parse_distribution_list(): + distribution = [1, 1, 0, 2] + distribution = parse_distribution(distribution) + assert distribution == {0: 1, 1: 1, 2: 0, 3: 2} + + +# Components initializers +def test_shape_aware_raises_error(): + with pytest.raises(TypeError): + _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) + + +def test_literal_comp_generate(): + protos = torch.rand(4, 3, 5, 5) + c = pt.initializers.LiteralCompInitializer(protos) + components = c.generate(num_components="IgnoreMe!") + assert torch.allclose(components, protos) + + +def test_zeros_comp_generate(): + shape = (3, 5, 5) + c = pt.initializers.ZerosCompInitializer(shape) + components = c.generate(num_components=4) + assert torch.allclose(components, torch.zeros(4, 3, 5, 5)) + + +def test_ones_comp_generate(): + c = pt.initializers.OnesCompInitializer(2) + components = c.generate(num_components=3) + assert torch.allclose(components, torch.ones(3, 2)) + + +def test_fill_value_comp_generate(): + c = pt.initializers.FillValueCompInitializer(2, 0.0) + components = c.generate(num_components=3) + assert torch.allclose(components, torch.zeros(3, 2)) + + +def test_comp_generate_0_components(): + c = pt.initializers.ZerosCompInitializer(2) + _ = c.generate(num_components=0) + + +def test_stratified_mean_comp_generate(): + # yapf: disable + x = torch.Tensor( + [[0, -1, -2], + [10, 11, 12], + [0, 0, 0], + [2, 2, 2]]) + y = torch.LongTensor([0, 0, 1, 1]) + desired = torch.Tensor( + [[5.0, 5.0, 5.0], + [1.0, 1.0, 1.0]]) + # yapf: enable + c = pt.initializers.StratifiedMeanCompInitializer(data=[x, y]) + actual = c.generate([1, 1]) + assert torch.allclose(actual, desired) + + +def test_stratified_selection_comp_generate(): + # yapf: disable + x = torch.Tensor( + [[0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [1, 1, 1]]) + y = torch.LongTensor([0, 1, 0, 1]) + desired = torch.Tensor( + [[0, 0, 0], + [1, 1, 1]]) + # yapf: enable + c = pt.initializers.StratifiedSelectionCompInitializer(data=[x, y]) + actual = c.generate([1, 1]) + assert torch.allclose(actual, desired) + + +# Labels initializers +def 
test_labels_init_from_list(): + l = pt.initializers.LabelsInitializer() + components = l.generate(distribution=[1, 1, 1]) + assert torch.allclose(components, torch.LongTensor([0, 1, 2])) + + +def test_labels_init_from_tuple_legal(): + l = pt.initializers.LabelsInitializer() + components = l.generate(distribution=(3, 1)) + assert torch.allclose(components, torch.LongTensor([0, 1, 2])) + + +def test_labels_init_from_tuple_illegal(): + l = pt.initializers.LabelsInitializer() + with pytest.raises(AssertionError): + _ = l.generate(distribution=(1, 1, 1)) + + +# Components +def test_components_no_initializer(): + with pytest.raises(TypeError): + _ = pt.components.Components(3, None) + + +def test_components_no_num_components(): + with pytest.raises(TypeError): + _ = pt.components.Components(initializer=pt.initializers.OCI(2)) + + +def test_components_none_num_components(): + with pytest.raises(TypeError): + _ = pt.components.Components(None, initializer=pt.initializers.OCI(2)) + + +def test_components_no_args(): + with pytest.raises(TypeError): + _ = pt.components.Components() + + +def test_components_zeros_init(): + c = pt.components.Components(3, pt.initializers.ZCI(2)) + assert torch.allclose(c.components, torch.zeros(3, 2)) + + +# Losses +def test_glvq_loss_int_labels(): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([0, 1]) + targets = torch.ones(100) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +def test_glvq_loss_one_hot_labels(): + d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) + labels = torch.tensor([[0, 1], [1, 0]]) + wl = torch.tensor([1, 0]) + targets = torch.stack([wl for _ in range(100)], dim=0) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +def test_glvq_loss_one_hot_unequal(): + dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)] + d = torch.stack(dlist, dim=1) + labels = torch.tensor([[0, 1], [1, 0], [1, 0]]) + wl = torch.tensor([1, 0]) + targets = torch.stack([wl for _ in range(100)], dim=0) + batch_loss = pt.losses.glvq_loss(distances=d, + target_labels=targets, + prototype_labels=labels) + loss_value = torch.sum(batch_loss, dim=0) + assert loss_value == -100 + + +# Activations +class TestActivations(unittest.TestCase): + def setUp(self): + self.flist = ["identity", "sigmoid_beta", "swish_beta"] + self.x = torch.randn(1024, 1) + + def test_registry(self): + self.assertIsNotNone(pt.nn.ACTIVATIONS) + + def test_funcname_deserialization(self): + for funcname in self.flist: + f = pt.nn.get_activation(funcname) + iscallable = callable(f) + self.assertTrue(iscallable) + + def test_callable_deserialization(self): + def dummy(x, **kwargs): + return x + + for f in [dummy, lambda x: x]: + f = pt.nn.get_activation(f) + iscallable = callable(f) + self.assertTrue(iscallable) + self.assertEqual(1, f(1)) + + def test_unknown_deserialization(self): + for funcname in ["blubb", "foobar"]: + with self.assertRaises(NameError): + _ = pt.nn.get_activation(funcname) + + def test_identity(self): + actual = pt.nn.identity(self.x) + desired = self.x + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_sigmoid_beta1(self): + actual = pt.nn.sigmoid_beta(self.x, beta=1.0) + desired = torch.sigmoid(self.x) + mismatch 
= np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_swish_beta1(self): + actual = pt.nn.swish_beta(self.x, beta=1.0) + desired = self.x * torch.sigmoid(self.x) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + del self.x + + +# Competitions +class TestCompetitions(unittest.TestCase): + def setUp(self): + pass + + def test_wtac(self): + d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) + labels = torch.tensor([0, 1, 2, 3]) + actual = pt.competitions.wtac(d, labels) + desired = torch.tensor([2, 0]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_wtac_unequal_dist(self): + d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]]) + labels = torch.tensor([0, 1, 1]) + actual = pt.competitions.wtac(d, labels) + desired = torch.tensor([0, 1]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_wtac_one_hot(self): + d = torch.tensor([[1.99, 3.01], [3.0, 2.01]]) + labels = torch.tensor([[0, 1], [1, 0]]) + actual = pt.competitions.wtac(d, labels) + desired = torch.tensor([[0, 1], [1, 0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_knnc_k1(self): + d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) + labels = torch.tensor([0, 1, 2, 3]) + actual = pt.competitions.knnc(d, labels, k=1) + desired = torch.tensor([2, 0]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + pass + + +# Pooling +class TestPooling(unittest.TestCase): + def setUp(self): + pass + + def test_stratified_min(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.tensor([0, 0, 1, 2]) + actual = pt.pooling.stratified_min_pooling(d, labels) + desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_min_one_hot(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.tensor([0, 0, 1, 2]) + labels = torch.eye(3)[labels] + actual = pt.pooling.stratified_min_pooling(d, labels) + desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_min_trivial(self): + d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) + labels = torch.tensor([0, 1, 2]) + actual = pt.pooling.stratified_min_pooling(d, labels) + desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_max(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 3, 2, 0]) + actual = pt.pooling.stratified_max_pooling(d, labels) + desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_max_one_hot(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 2, 1, 0]) + labels = 
torch.nn.functional.one_hot(labels, num_classes=3) + actual = pt.pooling.stratified_max_pooling(d, labels) + desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_sum(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.LongTensor([0, 0, 1, 2]) + actual = pt.pooling.stratified_sum_pooling(d, labels) + desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_sum_one_hot(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) + labels = torch.tensor([0, 0, 1, 2]) + labels = torch.eye(3)[labels] + actual = pt.pooling.stratified_sum_pooling(d, labels) + desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def test_stratified_prod(self): + d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) + labels = torch.tensor([0, 0, 3, 2, 0]) + actual = pt.pooling.stratified_prod_pooling(d, labels) + desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=5) + self.assertIsNone(mismatch) + + def tearDown(self): + pass + + +# Distances +class TestDistances(unittest.TestCase): + def setUp(self): + self.nx, self.mx = 32, 2048 + self.ny, self.my = 8, 2048 + self.x = torch.randn(self.nx, self.mx) + self.y = torch.randn(self.ny, self.my) + + def test_manhattan(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=1) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=1, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_euclidean(self): + actual = pt.distances.euclidean_distance(self.x, self.y) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=3) + self.assertIsNone(mismatch) + + def test_squared_euclidean(self): + actual = pt.distances.squared_euclidean_distance(self.x, self.y) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_lpnorm_p0(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=0) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=0, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_p2(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=2) + 
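# A hypothetical cross-check, not part of the original suite: for p=2, + # lpnorm_distance should agree with euclidean_distance. + assert torch.allclose(actual, pt.distances.euclidean_distance(self.x, self.y), atol=1e-4) +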
desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_p3(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=3) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=3, + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_lpnorm_pinf(self): + actual = pt.distances.lpnorm_distance(self.x, self.y, p=float("inf")) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=float("inf"), + keepdim=False, + ) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=4) + self.assertIsNone(mismatch) + + def test_omega_identity(self): + omega = torch.eye(self.mx, self.my) + actual = pt.distances.omega_distance(self.x, self.y, omega=omega) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def test_lomega_identity(self): + omega = torch.eye(self.mx, self.my) + omegas = torch.stack([omega for _ in range(self.ny)], dim=0) + actual = pt.distances.lomega_distance(self.x, self.y, omegas=omegas) + desired = torch.empty(self.nx, self.ny) + for i in range(self.nx): + for j in range(self.ny): + desired[i][j] = (torch.nn.functional.pairwise_distance( + self.x[i].reshape(1, -1), + self.y[j].reshape(1, -1), + p=2, + keepdim=False, + )**2) + mismatch = np.testing.assert_array_almost_equal(actual, + desired, + decimal=2) + self.assertIsNone(mismatch) + + def tearDown(self): + del self.x, self.y From 38244f6691cf31834eb8be201f20654a0ab6e25a Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:41:00 +0200 Subject: [PATCH 17/43] Add setup.cfg --- setup.cfg | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..33c1a02 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[pylint] +disable = + too-many-arguments, + too-few-public-methods, + fixme, + +[pycodestyle] +max-line-length = 79 + +[isort] +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 3 +use_parentheses = True +line_length = 79 \ No newline at end of file From b4ad870b7c2b92f5b149f73b31c2fa1f0f9e654a Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:48:09 +0200 Subject: [PATCH 18/43] Remove prototorch/functions and prototorch/modules --- prototorch/functions/__init__.py | 5 - prototorch/functions/activations.py | 62 ------- prototorch/functions/competitions.py | 28 --- prototorch/functions/distances.py | 258 -------------------------- prototorch/functions/helper.py | 94 ---------- prototorch/functions/initializers.py | 107 ----------- 
prototorch/functions/losses.py | 94 ---------- prototorch/functions/normalization.py | 35 ---- prototorch/functions/pooling.py | 80 -------- prototorch/functions/similarities.py | 18 -- prototorch/functions/transforms.py | 32 ---- prototorch/modules/__init__.py | 5 - 12 files changed, 818 deletions(-) delete mode 100644 prototorch/functions/__init__.py delete mode 100644 prototorch/functions/activations.py delete mode 100644 prototorch/functions/competitions.py delete mode 100644 prototorch/functions/distances.py delete mode 100644 prototorch/functions/helper.py delete mode 100644 prototorch/functions/initializers.py delete mode 100644 prototorch/functions/losses.py delete mode 100644 prototorch/functions/normalization.py delete mode 100644 prototorch/functions/pooling.py delete mode 100644 prototorch/functions/similarities.py delete mode 100644 prototorch/functions/transforms.py delete mode 100644 prototorch/modules/__init__.py diff --git a/prototorch/functions/__init__.py b/prototorch/functions/__init__.py deleted file mode 100644 index 9b3b993..0000000 --- a/prototorch/functions/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""ProtoTorch functions.""" - -from .activations import identity, sigmoid_beta, swish_beta -from .competitions import knnc, wtac -from .pooling import * diff --git a/prototorch/functions/activations.py b/prototorch/functions/activations.py deleted file mode 100644 index c5673ae..0000000 --- a/prototorch/functions/activations.py +++ /dev/null @@ -1,62 +0,0 @@ -"""ProtoTorch activation functions.""" - -import torch - -ACTIVATIONS = dict() - - -def register_activation(fn): - """Add the activation function to the registry.""" - name = fn.__name__ - ACTIVATIONS[name] = fn - return fn - - -@register_activation -def identity(x, beta=0.0): - """Identity activation function. - - Definition: - :math:`f(x) = x` - - Keyword Arguments: - beta (`float`): Ignored. - """ - return x - - -@register_activation -def sigmoid_beta(x, beta=10.0): - r"""Sigmoid activation function with scaling. - - Definition: - :math:`f(x) = \frac{1}{1 + e^{-\beta x}}` - - Keyword Arguments: - beta (`float`): Scaling parameter :math:`\beta` - """ - out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x)) - return out - - -@register_activation -def swish_beta(x, beta=10.0): - r"""Swish activation function with scaling. - - Definition: - :math:`f(x) = \frac{x}{1 + e^{-\beta x}}` - - Keyword Arguments: - beta (`float`): Scaling parameter :math:`\beta` - """ - out = x * sigmoid_beta(x, beta=beta) - return out - - -def get_activation(funcname): - """Deserialize the activation function.""" - if callable(funcname): - return funcname - if funcname in ACTIVATIONS: - return ACTIVATIONS.get(funcname) - raise NameError(f"Activation {funcname} was not found.") diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py deleted file mode 100644 index 326d510..0000000 --- a/prototorch/functions/competitions.py +++ /dev/null @@ -1,28 +0,0 @@ -"""ProtoTorch competition functions.""" - -import torch - - -def wtac(distances: torch.Tensor, - labels: torch.LongTensor) -> (torch.LongTensor): - """Winner-Takes-All-Competition. - - Returns the labels corresponding to the winners. - - """ - winning_indices = torch.min(distances, dim=1).indices - winning_labels = labels[winning_indices].squeeze() - return winning_labels - - -def knnc(distances: torch.Tensor, - labels: torch.LongTensor, - k: int = 1) -> (torch.LongTensor): - """K-Nearest-Neighbors-Competition. - - Returns the labels corresponding to the winners. 
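- Example (a hypothetical call; majority vote among the k=3 nearest): - >>> d = torch.tensor([[0.1, 0.2, 0.3, 0.4]]) - >>> knnc(d, torch.tensor([0, 1, 1, 2]), k=3) - tensor([1])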
- - """ - winning_indices = torch.topk(-distances, k=k, dim=1).indices - winning_labels = torch.mode(labels[winning_indices], dim=1).values - return winning_labels diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py deleted file mode 100644 index 5bea3c4..0000000 --- a/prototorch/functions/distances.py +++ /dev/null @@ -1,258 +0,0 @@ -"""ProtoTorch distance functions.""" - -import numpy as np -import torch -from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, - equal_int_shape, get_flat) - - -def squared_euclidean_distance(x, y): - r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. - - Compute :math:`{\langle \bm x - \bm y \rangle}_2` - - **Alias:** - ``prototorch.functions.distances.sed`` - """ - x, y = get_flat(x, y) - expanded_x = x.unsqueeze(dim=1) - batchwise_difference = y - expanded_x - differences_raised = torch.pow(batchwise_difference, 2) - distances = torch.sum(differences_raised, axis=2) - return distances - - -def euclidean_distance(x, y): - r"""Compute the Euclidean distance between :math:`x` and :math:`y`. - - Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}` - - :returns: Distance Tensor of shape :math:`X \times Y` - :rtype: `torch.tensor` - """ - x, y = get_flat(x, y) - distances_raised = squared_euclidean_distance(x, y) - distances = torch.sqrt(distances_raised) - return distances - - -def euclidean_distance_v2(x, y): - x, y = get_flat(x, y) - diff = y - x.unsqueeze(1) - pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() - # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the - # batch diagonal. See: - # https://pytorch.org/docs/stable/generated/torch.diagonal.html - distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1) - # print(f"{diff.shape=}") # (nx, ny, ndim) - # print(f"{pairwise_distances.shape=}") # (nx, ny, ny) - # print(f"{distances.shape=}") # (nx, ny) - return distances - - -def lpnorm_distance(x, y, p): - r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`. - Also known as Minkowski distance. - - Compute :math:`{\| \bm x - \bm y \|}_p`. - - Calls ``torch.cdist`` - - :param p: p parameter of the lp norm - """ - x, y = get_flat(x, y) - distances = torch.cdist(x, y, p=p) - return distances - - -def omega_distance(x, y, omega): - r"""Omega distance. - - Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p` - - :param `torch.tensor` omega: Two dimensional matrix - """ - x, y = get_flat(x, y) - projected_x = x @ omega - projected_y = y @ omega - distances = squared_euclidean_distance(projected_x, projected_y) - return distances - - -def lomega_distance(x, y, omegas): - r"""Localized Omega distance. - - Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p` - - :param `torch.tensor` omegas: Three dimensional matrix - """ - x, y = get_flat(x, y) - projected_x = x @ omegas - projected_y = torch.diagonal(y @ omegas).T - expanded_y = torch.unsqueeze(projected_y, dim=1) - batchwise_difference = expanded_y - projected_x - differences_squared = batchwise_difference**2 - distances = torch.sum(differences_squared, dim=2) - distances = distances.permute(1, 0) - return distances - - -def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10): - r"""Computes an euclidean distances matrix given two distinct vectors. - last dimension must be the vector dimension! - compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction! 
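- Concretely, the identity used is - :math:`{\| \bm x - \bm y \|}^2 = {\| \bm x \|}^2 - 2\,\bm x^\top \bm y + {\| \bm y \|}^2`, - evaluated batchwise below.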
- - - ``x.shape = (number_of_x_vectors, vector_dim)`` - - ``y.shape = (number_of_y_vectors, vector_dim)`` - - output: matrix of distances (number_of_x_vectors, number_of_y_vectors) - """ - for tensor in [x, y]: - if tensor.ndim != 2: - raise ValueError( - "The tensor dimension must be two. You provide: tensor.ndim=" + - str(tensor.ndim) + ".") - if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): - raise ValueError( - "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]=" - + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" + - str(tuple(y.shape)[1]) + ".") - - y = torch.transpose(y) - - diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) + - torch.sum(y**2, axis=0, keepdims=True)) - - if not squared: - if epsilon == 0: - diss = torch.sqrt(diss) - else: - diss = torch.sqrt(torch.max(diss, epsilon)) - - return diss - - -def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): - r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews - - For more info about Tangen distances see - - DOI:10.1109/IJCNN.2016.7727534. - - The subspaces is always assumed as transposed and must be orthogonal! - For local non sparse signals subspaces must be provided! - - - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN - - shape(protos): proto_number x dim1 x dim2 x ... x dimN - - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape) - - subspace should be orthogonalized - Pytorch implementation of Sascha Saralajew's tensorflow code. - Translation by Christoph Raab - """ - signal_shape, signal_int_shape = _int_and_mixed_shape(signals) - proto_shape, proto_int_shape = _int_and_mixed_shape(protos) - subspace_int_shape = tuple(subspaces.shape) - - # check if the shapes are correct - _check_shapes(signal_int_shape, proto_int_shape) - - atom_axes = list(range(3, len(signal_int_shape))) - # for sparse signals, we use the memory efficient implementation - if signal_int_shape[1] == 1: - signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])]) - - if len(atom_axes) > 1: - protos = torch.reshape(protos, [proto_shape[0], -1]) - - if subspaces.ndim == 2: - # clean solution without map if the matrix_scope is global - projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( - subspaces, torch.transpose(subspaces)) - - projected_signals = torch.dot(signals, projectors) - projected_protos = torch.dot(protos, projectors) - - diss = euclidean_distance_matrix(projected_signals, - projected_protos, - squared=squared, - epsilon=epsilon) - - diss = torch.reshape( - diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - - return torch.permute(diss, [0, 2, 1]) - - else: - - # no solution without map possible --> memory efficient but slow! 
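- # Intended here (following the referenced TensorFlow code): one projector - # per local subspace, P_k = I - U_k U_k^T, applied per projector via a - # Python-level map.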
- projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( - subspaces, - subspaces) # K.batch_dot(subspaces, subspaces, [2, 2]) - - projected_protos = (protos @ subspaces - ).T # K.batch_dot(projectors, protos, [1, 1])) - - def projected_norm(projector): - return torch.sum(torch.dot(signals, projector)**2, axis=1) - - diss = (torch.transpose(map(projected_norm, projectors)) - - 2 * torch.dot(signals, projected_protos) + - torch.sum(projected_protos**2, axis=0, keepdims=True)) - - if not squared: - if epsilon == 0: - diss = torch.sqrt(diss) - else: - diss = torch.sqrt(torch.max(diss, epsilon)) - - diss = torch.reshape( - diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - - return torch.permute(diss, [0, 2, 1]) - - else: - signals = signals.permute([0, 2, 1] + atom_axes) - - diff = signals - protos - - # global tangent space - if subspaces.ndim == 2: - # Scope Projectors - projectors = subspaces # - - # Scope: Tangentspace Projections - diff = torch.reshape( - diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) - projected_diff = diff @ projectors - projected_diff = torch.reshape( - projected_diff, - (signal_shape[0], signal_shape[2], signal_shape[1]) + - signal_shape[3:], - ) - - diss = torch.norm(projected_diff, 2, dim=-1) - return diss.permute([0, 2, 1]) - - # local tangent spaces - else: - # Scope: Calculate Projectors - projectors = subspaces - - # Scope: Tangentspace Projections - diff = torch.reshape( - diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) - diff = diff.permute([1, 0, 2]) - projected_diff = torch.bmm(diff, projectors) - projected_diff = torch.reshape( - projected_diff, - (signal_shape[1], signal_shape[0], signal_shape[2]) + - signal_shape[3:], - ) - - diss = torch.norm(projected_diff, 2, dim=-1) - return diss.permute([1, 0, 2]).squeeze(-1) - - -# Aliases -sed = squared_euclidean_distance diff --git a/prototorch/functions/helper.py b/prototorch/functions/helper.py deleted file mode 100644 index 6797a72..0000000 --- a/prototorch/functions/helper.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch - - -def get_flat(*args): - rv = [x.view(x.size(0), -1) for x in args] - return rv - - -def calculate_prototype_accuracy(y_pred, y_true, plabels): - """Computes the accuracy of a prototype based model. - via Winner-Takes-All rule. - Requirement: - y_pred.shape == y_true.shape - unique(y_pred) in plabels - """ - with torch.no_grad(): - idx = torch.argmin(y_pred, axis=1) - return torch.true_divide(torch.sum(y_true == plabels[idx]), - len(y_pred)) * 100 - - -def predict_label(y_pred, plabels): - r""" Predicts labels given a prediction of a prototype based model. 
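- Example (hypothetical): ``predict_label(torch.tensor([[0.2, 0.1]]), torch.tensor([1, 3]))`` - returns ``tensor([3])``, the label of the closest prototype.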
- """ - with torch.no_grad(): - return plabels[torch.argmin(y_pred, 1)] - - -def mixed_shape(inputs): - if not torch.is_tensor(inputs): - raise ValueError("Input must be a tensor.") - else: - int_shape = list(inputs.shape) - # sometimes int_shape returns mixed integer types - int_shape = [int(i) if i is not None else i for i in int_shape] - tensor_shape = inputs.shape - - for i, s in enumerate(int_shape): - if s is None: - int_shape[i] = tensor_shape[i] - return tuple(int_shape) - - -def equal_int_shape(shape_1, shape_2): - if not isinstance(shape_1, - (tuple, list)) or not isinstance(shape_2, (tuple, list)): - raise ValueError("Input shapes must list or tuple.") - for shape in [shape_1, shape_2]: - if not all([isinstance(x, int) or x is None for x in shape]): - raise ValueError( - "Input shapes must be list or tuple of int and None values.") - - if len(shape_1) != len(shape_2): - return False - else: - for axis, value in enumerate(shape_1): - if value is not None and shape_2[axis] not in {value, None}: - return False - return True - - -def _check_shapes(signal_int_shape, proto_int_shape): - if len(signal_int_shape) < 4: - raise ValueError( - "The number of signal dimensions must be >=4. You provide: " + - str(len(signal_int_shape))) - - if len(proto_int_shape) < 2: - raise ValueError( - "The number of proto dimensions must be >=2. You provide: " + - str(len(proto_int_shape))) - - if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]): - raise ValueError( - "The atom shape of signals must be equal protos. You provide: signals.shape[3:]=" - + str(signal_int_shape[3:]) + " != protos.shape[1:]=" + - str(proto_int_shape[1:])) - - # not a sparse signal - if signal_int_shape[1] != 1: - if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]): - raise ValueError( - "If the signal is not sparse, the number of prototypes must be equal in signals and " - "protos. 
You provide: " + str(signal_int_shape[1]) + " != " + - str(proto_int_shape[0])) - - return True - - -def _int_and_mixed_shape(tensor): - shape = mixed_shape(tensor) - int_shape = tuple([i if isinstance(i, int) else None for i in shape]) - - return shape, int_shape diff --git a/prototorch/functions/initializers.py b/prototorch/functions/initializers.py deleted file mode 100644 index 345b723..0000000 --- a/prototorch/functions/initializers.py +++ /dev/null @@ -1,107 +0,0 @@ -"""ProtoTorch initialization functions.""" - -from itertools import chain - -import torch - -INITIALIZERS = dict() - - -def register_initializer(function): - """Add the initializer to the registry.""" - INITIALIZERS[function.__name__] = function - return function - - -def labels_from(distribution, one_hot=True): - """Takes a distribution tensor and returns a labels tensor.""" - num_classes = distribution.shape[0] - llist = [[i] * n for i, n in zip(range(num_classes), distribution)] - # labels = [l for cl in llist for l in cl] # flatten the list of lists - flat_llist = list(chain(*llist)) # flatten label list with itertools.chain - plabels = torch.tensor(flat_llist, requires_grad=False) - if one_hot: - return torch.eye(num_classes)[plabels] - return plabels - - -@register_initializer -def ones(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.ones(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def zeros(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.zeros(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def rand(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.rand(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def randn(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - protos = torch.randn(num_protos, *x_train.shape[1:]) - plabels = labels_from(prototype_distribution, one_hot) - return protos, plabels - - -@register_initializer -def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True): - num_protos = torch.sum(prototype_distribution) - pdim = x_train.shape[1] - protos = torch.empty(num_protos, pdim) - plabels = labels_from(prototype_distribution, one_hot) - for i, label in enumerate(plabels): - matcher = torch.eq(label.unsqueeze(dim=0), y_train) - if one_hot: - num_classes = y_train.size()[1] - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - xl = x_train[matcher] - mean_xl = torch.mean(xl, dim=0) - protos[i] = mean_xl - plabels = labels_from(prototype_distribution, one_hot=one_hot) - return protos, plabels - - -@register_initializer -def stratified_random(x_train, - y_train, - prototype_distribution, - one_hot=True, - epsilon=1e-7): - num_protos = torch.sum(prototype_distribution) - pdim = x_train.shape[1] - protos = torch.empty(num_protos, pdim) - plabels = labels_from(prototype_distribution, one_hot) - for i, label in enumerate(plabels): - matcher = torch.eq(label.unsqueeze(dim=0), y_train) - if one_hot: - num_classes = y_train.size()[1] - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - xl = x_train[matcher] - 
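# Draw one sample of this class at random. Note that random_'s upper - # bound is exclusive, so the last matching sample is never drawn. -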
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1) - random_xl = xl[rand_index] - protos[i] = random_xl + epsilon - plabels = labels_from(prototype_distribution, one_hot=one_hot) - return protos, plabels - - -def get_initializer(funcname): - """Deserialize the initializer.""" - if callable(funcname): - return funcname - if funcname in INITIALIZERS: - return INITIALIZERS.get(funcname) - raise NameError(f"Initializer {funcname} was not found.") diff --git a/prototorch/functions/losses.py b/prototorch/functions/losses.py deleted file mode 100644 index 249882a..0000000 --- a/prototorch/functions/losses.py +++ /dev/null @@ -1,94 +0,0 @@ -"""ProtoTorch loss functions.""" - -import torch - - -def _get_matcher(targets, labels): - """Returns a boolean tensor.""" - matcher = torch.eq(targets.unsqueeze(dim=1), labels) - if labels.ndim == 2: - # if the labels are one-hot vectors - num_classes = targets.size()[1] - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - return matcher - - -def _get_dp_dm(distances, targets, plabels, with_indices=False): - """Returns the d+ and d- values for a batch of distances.""" - matcher = _get_matcher(targets, plabels) - not_matcher = torch.bitwise_not(matcher) - - inf = torch.full_like(distances, fill_value=float("inf")) - d_matching = torch.where(matcher, distances, inf) - d_unmatching = torch.where(not_matcher, distances, inf) - dp = torch.min(d_matching, dim=-1, keepdim=True) - dm = torch.min(d_unmatching, dim=-1, keepdim=True) - if with_indices: - return dp, dm - return dp.values, dm.values - - -def glvq_loss(distances, target_labels, prototype_labels): - """GLVQ loss function with support for one-hot labels.""" - dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) - mu = (dp - dm) / (dp + dm) - return mu - - -def lvq1_loss(distances, target_labels, prototype_labels): - """LVQ1 loss function with support for one-hot labels. - - See Section 4 [Sado&Yamada] - https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf - """ - dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) - mu = dp - mu[dp > dm] = -dm[dp > dm] - return mu - - -def lvq21_loss(distances, target_labels, prototype_labels): - """LVQ2.1 loss function with support for one-hot labels. 
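- - Computes :math:`\mu = d^+ - d^-`, where :math:`d^+` and :math:`d^-` are the - distances to the closest prototype with a matching and a mismatching label.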
- - See Section 4 [Sado&Yamada] - https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf - """ - dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) - mu = dp - dm - - return mu - - -# Probabilistic -def _get_class_probabilities(probabilities, targets, prototype_labels): - # Create Label Mapping - uniques = prototype_labels.unique(sorted=True).tolist() - key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} - - target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist()))) - - whole = probabilities.sum(dim=1) - correct = probabilities[torch.arange(len(probabilities)), target_indices] - wrong = whole - correct - - return whole, correct, wrong - - -def nllr_loss(probabilities, targets, prototype_labels): - """Compute the Negative Log-Likelihood Ratio loss.""" - _, correct, wrong = _get_class_probabilities(probabilities, targets, - prototype_labels) - - likelihood = correct / wrong - log_likelihood = torch.log(likelihood) - return -1.0 * log_likelihood - - -def rslvq_loss(probabilities, targets, prototype_labels): - """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss.""" - whole, correct, _ = _get_class_probabilities(probabilities, targets, - prototype_labels) - - likelihood = correct / whole - log_likelihood = torch.log(likelihood) - return -1.0 * log_likelihood diff --git a/prototorch/functions/normalization.py b/prototorch/functions/normalization.py deleted file mode 100644 index 96980b8..0000000 --- a/prototorch/functions/normalization.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, division, print_function - -import torch - - -def orthogonalization(tensors): - r""" Orthogonalization of a given tensor via polar decomposition. - """ - u, _, v = torch.svd(tensors, compute_uv=True) - u_shape = tuple(list(u.shape)) - v_shape = tuple(list(v.shape)) - - # reshape to (num x N x M) - u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1])) - v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1])) - - out = u @ v.permute([0, 2, 1]) - - out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], )) - - return out - - -def trace_normalization(tensors): - r""" Trace normalization - """ - epsilon = torch.tensor([1e-10], dtype=torch.float64) - # Scope trace_normalization - constant = torch.trace(tensors) - - if epsilon != 0: - constant = torch.max(constant, epsilon) - - return tensors / constant diff --git a/prototorch/functions/pooling.py b/prototorch/functions/pooling.py deleted file mode 100644 index 6dd427e..0000000 --- a/prototorch/functions/pooling.py +++ /dev/null @@ -1,80 +0,0 @@ -"""ProtoTorch pooling functions.""" - -from typing import Callable - -import torch - - -def stratify_with(values: torch.Tensor, - labels: torch.LongTensor, - fn: Callable, - fill_value: float = 0.0) -> (torch.Tensor): - """Apply an arbitrary stratification strategy on the columns on `values`. - - The outputs correspond to sorted labels. 
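- - Example (hypothetical): with ``labels = [0, 0, 1]`` and a column-wise minimum - as ``fn``, columns 0 and 1 are reduced into class 0 and column 2 into class 1.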
- """ - clabels = torch.unique(labels, dim=0, sorted=True) - num_classes = clabels.size()[0] - if values.size()[1] == num_classes: - # skip if stratification is trivial - return values - batch_size = values.size()[0] - winning_values = torch.zeros(num_classes, batch_size, device=labels.device) - filler = torch.full_like(values.T, fill_value=fill_value) - for i, cl in enumerate(clabels): - matcher = torch.eq(labels.unsqueeze(dim=1), cl) - if labels.ndim == 2: - # if the labels are one-hot vectors - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - cdists = torch.where(matcher, values.T, filler).T - winning_values[i] = fn(cdists) - if labels.ndim == 2: - # Transpose to return with `batch_size` first and - # reverse the columns to fix the ordering of the classes - return torch.flip(winning_values.T, dims=(1, )) - - return winning_values.T # return with `batch_size` first - - - def stratified_sum_pooling(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise sum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(), - fill_value=0.0) - return winning_values - - - def stratified_min_pooling(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise minimum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(), - fill_value=float("inf")) - return winning_values - - - def stratified_max_pooling(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise maximum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(), - fill_value=-1.0 * float("inf")) - return winning_values - - - def stratified_prod_pooling(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise product.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(), - fill_value=1.0) - return winning_values diff --git a/prototorch/functions/similarities.py b/prototorch/functions/similarities.py deleted file mode 100644 index cc91c78..0000000 --- a/prototorch/functions/similarities.py +++ /dev/null @@ -1,18 +0,0 @@ -"""ProtoTorch similarity functions.""" - -import torch - - -def cosine_similarity(x, y): - """Compute the cosine similarity between :math:`x` and :math:`y`. - - Expected dimension of x is 2. - Expected dimension of y is 2.
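- - Returns the matrix :math:`S_{ij} = \frac{\bm x_i^\top \bm y_j}{\max(\| \bm x_i \| \, \| \bm y_j \|, \epsilon)}` - of shape ``(len(x), len(y))``.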
- """ - norm_x = x.pow(2).sum(1).sqrt() - norm_y = y.pow(2).sum(1).sqrt() - norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T - epsilon = torch.finfo(norm_mat.dtype).eps - norm_mat.clamp_(min=epsilon) - similarities = (x @ y.T) / norm_mat - return similarities diff --git a/prototorch/functions/transforms.py b/prototorch/functions/transforms.py deleted file mode 100644 index 334d382..0000000 --- a/prototorch/functions/transforms.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch - - -# Functions -def gaussian(distances, variance): - return torch.exp(-(distances * distances) / (2 * variance)) - - -def rank_scaled_gaussian(distances, lambd): - order = torch.argsort(distances, dim=1) - ranks = torch.argsort(order, dim=1) - - return torch.exp(-torch.exp(-ranks / lambd) * distances) - - -# Modules -class GaussianPrior(torch.nn.Module): - def __init__(self, variance): - super().__init__() - self.variance = variance - - def forward(self, distances): - return gaussian(distances, self.variance) - - -class RankScaledGaussianPrior(torch.nn.Module): - def __init__(self, lambd): - super().__init__() - self.lambd = lambd - - def forward(self, distances): - return rank_scaled_gaussian(distances, self.lambd) diff --git a/prototorch/modules/__init__.py b/prototorch/modules/__init__.py deleted file mode 100644 index fc7ab87..0000000 --- a/prototorch/modules/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""ProtoTorch modules.""" - -from .competitions import * -from .pooling import * -from .wrappers import LambdaLayer, LossLayer From d26a626677e6350fbc83f0cff9a821313dd0092d Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:48:39 +0200 Subject: [PATCH 19/43] Temporarily remove tangent distance --- prototorch/core/distances.py | 163 ----------------------------------- 1 file changed, 163 deletions(-) diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index 0782769..c19a8dc 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -1,15 +1,7 @@ """ProtoTorch distances""" -import numpy as np import torch -# from prototorch.functions.helper import ( -# _check_shapes, -# _int_and_mixed_shape, -# equal_int_shape, -# get_flat, -# ) - def squared_euclidean_distance(x, y): r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`. @@ -102,160 +94,5 @@ def lomega_distance(x, y, omegas): return distances -# def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10): -# r"""Computes an euclidean distances matrix given two distinct vectors. -# last dimension must be the vector dimension! -# compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction! - -# - ``x.shape = (number_of_x_vectors, vector_dim)`` -# - ``y.shape = (number_of_y_vectors, vector_dim)`` - -# output: matrix of distances (number_of_x_vectors, number_of_y_vectors) -# """ -# for tensor in [x, y]: -# if tensor.ndim != 2: -# raise ValueError( -# "The tensor dimension must be two. You provide: tensor.ndim=" + -# str(tensor.ndim) + ".") -# if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]): -# raise ValueError( -# "The vector shape must be equivalent in both tensors. 
You provide: tuple(y.shape)[1]=" -# + str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" + -# str(tuple(y.shape)[1]) + ".") - -# y = torch.transpose(y) - -# diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) + -# torch.sum(y**2, axis=0, keepdims=True)) - -# if not squared: -# if epsilon == 0: -# diss = torch.sqrt(diss) -# else: -# diss = torch.sqrt(torch.max(diss, epsilon)) - -# return diss - -# def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): -# r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews - -# For more info about Tangen distances see - -# DOI:10.1109/IJCNN.2016.7727534. - -# The subspaces is always assumed as transposed and must be orthogonal! -# For local non sparse signals subspaces must be provided! - -# - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN -# - shape(protos): proto_number x dim1 x dim2 x ... x dimN -# - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape) - -# subspace should be orthogonalized -# Pytorch implementation of Sascha Saralajew's tensorflow code. -# Translation by Christoph Raab -# """ -# signal_shape, signal_int_shape = _int_and_mixed_shape(signals) -# proto_shape, proto_int_shape = _int_and_mixed_shape(protos) -# subspace_int_shape = tuple(subspaces.shape) - -# # check if the shapes are correct -# _check_shapes(signal_int_shape, proto_int_shape) - -# atom_axes = list(range(3, len(signal_int_shape))) -# # for sparse signals, we use the memory efficient implementation -# if signal_int_shape[1] == 1: -# signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])]) - -# if len(atom_axes) > 1: -# protos = torch.reshape(protos, [proto_shape[0], -1]) - -# if subspaces.ndim == 2: -# # clean solution without map if the matrix_scope is global -# projectors = torch.eye(subspace_int_shape[-2]) - torch.dot( -# subspaces, torch.transpose(subspaces)) - -# projected_signals = torch.dot(signals, projectors) -# projected_protos = torch.dot(protos, projectors) - -# diss = euclidean_distance_matrix(projected_signals, -# projected_protos, -# squared=squared, -# epsilon=epsilon) - -# diss = torch.reshape( -# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - -# return torch.permute(diss, [0, 2, 1]) - -# else: - -# # no solution without map possible --> memory efficient but slow! 
-# projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm( -# subspaces, -# subspaces) # K.batch_dot(subspaces, subspaces, [2, 2]) - -# projected_protos = (protos @ subspaces -# ).T # K.batch_dot(projectors, protos, [1, 1])) - -# def projected_norm(projector): -# return torch.sum(torch.dot(signals, projector)**2, axis=1) - -# diss = (torch.transpose(map(projected_norm, projectors)) - -# 2 * torch.dot(signals, projected_protos) + -# torch.sum(projected_protos**2, axis=0, keepdims=True)) - -# if not squared: -# if epsilon == 0: -# diss = torch.sqrt(diss) -# else: -# diss = torch.sqrt(torch.max(diss, epsilon)) - -# diss = torch.reshape( -# diss, [signal_shape[0], signal_shape[2], proto_shape[0]]) - -# return torch.permute(diss, [0, 2, 1]) - -# else: -# signals = signals.permute([0, 2, 1] + atom_axes) - -# diff = signals - protos - -# # global tangent space -# if subspaces.ndim == 2: -# # Scope Projectors -# projectors = subspaces # - -# # Scope: Tangentspace Projections -# diff = torch.reshape( -# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) -# projected_diff = diff @ projectors -# projected_diff = torch.reshape( -# projected_diff, -# (signal_shape[0], signal_shape[2], signal_shape[1]) + -# signal_shape[3:], -# ) - -# diss = torch.norm(projected_diff, 2, dim=-1) -# return diss.permute([0, 2, 1]) - -# # local tangent spaces -# else: -# # Scope: Calculate Projectors -# projectors = subspaces - -# # Scope: Tangentspace Projections -# diff = torch.reshape( -# diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)) -# diff = diff.permute([1, 0, 2]) -# projected_diff = torch.bmm(diff, projectors) -# projected_diff = torch.reshape( -# projected_diff, -# (signal_shape[1], signal_shape[0], signal_shape[2]) + -# signal_shape[3:], -# ) - -# diss = torch.norm(projected_diff, 2, dim=-1) -# return diss.permute([1, 0, 2]).squeeze(-1) - # Aliases sed = squared_euclidean_distance From 935d9fa7adde1e315c1cb1c954031e60f4106e07 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sat, 12 Jun 2021 20:50:04 +0200 Subject: [PATCH 20/43] Add similarities --- prototorch/core/__init__.py | 3 +++ prototorch/core/similarities.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 prototorch/core/similarities.py diff --git a/prototorch/core/__init__.py b/prototorch/core/__init__.py index 4badc95..c205dfa 100644 --- a/prototorch/core/__init__.py +++ b/prototorch/core/__init__.py @@ -2,5 +2,8 @@ from .competitions import * from .components import * +from .distances import * from .initializers import * from .losses import * +from .pooling import * +from .similarities import * diff --git a/prototorch/core/similarities.py b/prototorch/core/similarities.py new file mode 100644 index 0000000..6125f8e --- /dev/null +++ b/prototorch/core/similarities.py @@ -0,0 +1,19 @@ +"""ProtoTorch similarities.""" + +import torch + + +def cosine_similarity(x, y): + """Compute the cosine similarity between :math:`x` and :math:`y`. + + Expected dimension of x is 2. + Expected dimension of y is 2. 
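+ + Example (hypothetical): ``cosine_similarity(torch.eye(2), torch.eye(2))`` + returns the identity matrix, since each basis vector matches only itself.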
+ """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + norm_x = x.pow(2).sum(1).sqrt() + norm_y = y.pow(2).sum(1).sqrt() + norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T + epsilon = torch.finfo(norm_mat.dtype).eps + norm_mat.clamp_(min=epsilon) + similarities = (x @ y.T) / norm_mat + return similarities From 84e08955f782e388550168843b710d2a213e471f Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 13 Jun 2021 17:02:57 +0000 Subject: [PATCH 21/43] Check if build passes with python3.9 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 77f8083..6320573 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ dist: bionic sudo: false language: python -python: 3.8 +python: 3.9 cache: directories: - "$HOME/.cache/pip" From 2af1da7f23705ff0162bc83d27b3575f88aef455 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 13 Jun 2021 22:54:29 +0000 Subject: [PATCH 22/43] Add standalone labels module --- prototorch/core/components.py | 158 +++++++++++++++++++++++++--------- tests/test_core.py | 34 ++++++++ 2 files changed, 149 insertions(+), 43 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index 53555af..d2b0f40 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -41,6 +41,24 @@ def validate_reasonings_initializer(initializer): return validate_initializer(initializer, AbstractReasoningsInitializer) +def gencat(ins, attr, init, *iargs, **ikwargs): + """Generate new items and concatenate with existing items.""" + new_items = init.generate(*iargs, **ikwargs) + if hasattr(ins, attr): + items = torch.cat([getattr(ins, attr), new_items]) + else: + items = new_items + return items, new_items + + +def removeind(ins, attr, indices): + """Remove items at specified indices.""" + mask = torch.ones(len(ins), dtype=torch.bool) + mask[indices] = False + items = getattr(ins, attr)[mask] + return items, mask + + class AbstractComponents(torch.nn.Module): """Abstract class for all components modules.""" @property @@ -57,7 +75,10 @@ class AbstractComponents(torch.nn.Module): self.register_parameter("_components", Parameter(components)) def extra_repr(self): - return f"(components): (shape: {tuple(self._components.shape)})" + return f"components: (shape: {tuple(self._components.shape)})" + + def __len__(self): + return self.num_components class Components(AbstractComponents): @@ -67,24 +88,18 @@ class Components(AbstractComponents): super().__init__(**kwargs) self.add_components(num_components, initializer) - def add_components(self, num: int, + def add_components(self, num_components: int, initializer: AbstractComponentsInitializer): - """Add new components.""" + """Generate and add new components.""" assert validate_components_initializer(initializer) - new_components = initializer.generate(num) - # Register - if hasattr(self, "_components"): - _components = torch.cat([self._components, new_components]) - else: - _components = new_components + _components, new_components = gencat(self, "_components", initializer, + num_components) self._register_components(_components) return new_components def remove_components(self, indices): """Remove components at specified indices.""" - mask = torch.ones(self.num_components, dtype=torch.bool) - mask[indices] = False - _components = self._components[mask] + _components, mask = removeind(self, "_components", indices) self._register_components(_components) return mask @@ -93,19 +108,90 @@ class 
Components(AbstractComponents): return self._components +class AbstractLabels(torch.nn.Module): + """Abstract class for all labels modules.""" + @property + def labels(self): + return self._labels + + @property + def num_labels(self): + return len(self.labels) + + @property + def unique_labels(self): + return torch.unique(self._labels) + + @property + def num_unique(self): + return len(self.unique_labels) + + @property + def distribution(self): + unique, counts = torch.unique(self._labels, + sorted=True, + return_counts=True) + return dict(zip(unique.tolist(), counts.tolist())) + + def _register_labels(self, labels): + self.register_buffer("_labels", labels) + + def extra_repr(self): + r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}" + if len(self.distribution) < 11: # avoid lengthy representations + d = self.distribution + unique, counts = list(d.keys()), list(d.values()) + r += f", unique: {unique}, counts: {counts}" + return r + + def __len__(self): + return self.num_labels + + +class Labels(AbstractLabels): + """A set of standalone labels.""" + def __init__(self, + distribution: Union[dict, list, tuple], + initializer: AbstractLabelsInitializer = LabelsInitializer(), + **kwargs): + super().__init__(**kwargs) + self.add_labels(distribution, initializer) + + def add_labels( + self, + distribution: Union[dict, tuple, list], + initializer: AbstractLabelsInitializer = LabelsInitializer()): + """Generate and add new labels.""" + assert validate_labels_initializer(initializer) + _labels, new_labels = gencat(self, "_labels", initializer, + distribution) + self._register_labels(_labels) + return new_labels + + def remove_labels(self, indices): + """Remove labels at specified indices.""" + _labels, mask = removeind(self, "_labels", indices) + self._register_labels(_labels) + return mask + + class LabeledComponents(AbstractComponents): """A set of adaptable components and corresponding unadaptable labels.""" - def __init__(self, distribution: Union[dict, list, tuple], - components_initializer: AbstractComponentsInitializer, - labels_initializer: AbstractLabelsInitializer, **kwargs): + def __init__( + self, + distribution: Union[dict, list, tuple], + components_initializer: AbstractComponentsInitializer, + labels_initializer: AbstractLabelsInitializer = LabelsInitializer( + ), + **kwargs): super().__init__(**kwargs) self.add_components(distribution, components_initializer, labels_initializer) @property - def component_labels(self): - """Tensor containing the component tensors.""" - return self._labels.detach() + def labels(self): + """Tensor containing the component labels.""" + return self._labels def _register_labels(self, labels): self.register_buffer("_labels", labels) @@ -115,42 +201,28 @@ class LabeledComponents(AbstractComponents): distribution, components_initializer, labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): - # Checks + """Generate and add new components and labels.""" assert validate_components_initializer(components_initializer) assert validate_labels_initializer(labels_initializer) - - distribution = parse_distribution(distribution) - - # Generate new components if isinstance(components_initializer, ClassAwareCompInitializer): - new_components = components_initializer.generate(distribution) + cikwargs = dict(distribution=distribution) else: + distribution = parse_distribution(distribution) num_components = sum(distribution.values()) - new_components = components_initializer.generate(num_components) - - # Generate new labels - new_labels = 
labels_initializer.generate(distribution) - - # Register - if hasattr(self, "_components"): - _components = torch.cat([self._components, new_components]) - else: - _components = new_components - if hasattr(self, "_labels"): - _labels = torch.cat([self._labels, new_labels]) - else: - _labels = new_labels + cikwargs = dict(num_components=num_components) + _components, new_components = gencat(self, "_components", + components_initializer, + **cikwargs) + _labels, new_labels = gencat(self, "_labels", labels_initializer, + distribution) self._register_components(_components) self._register_labels(_labels) - return new_components, new_labels def remove_components(self, indices): """Remove components and labels at specified indices.""" - mask = torch.ones(self.num_components, dtype=torch.bool) - mask[indices] = False - _components = self._components[mask] - _labels = self._labels[mask] + _components, mask = removeind(self, "_components", indices) + _labels, mask = removeind(self, "_labels", indices) self._register_components(_components) self._register_labels(_labels) return mask diff --git a/tests/test_core.py b/tests/test_core.py index d2496c8..191569e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -157,6 +157,40 @@ def test_components_zeros_init(): assert torch.allclose(c.components, torch.zeros(3, 2)) +def test_labeled_components_dict_init(): + c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long)) + + +def test_labeled_components_list_init(): + c = pt.components.LabeledComponents([3], pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long)) + + +def test_labeled_components_tuple_init(): + c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2)) + assert torch.allclose(c.components, torch.ones(3, 2)) + assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1])) + + +# Labels +def test_standalone_labels_dict_init(): + l = pt.components.Labels({0: 3}) + assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long)) + + +def test_standalone_labels_list_init(): + l = pt.components.Labels([3]) + assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long)) + + +def test_standalone_labels_tuple_init(): + l = pt.components.Labels({0: 1, 1: 2}) + assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1])) + + # Losses def test_glvq_loss_int_labels(): d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) From 6ad665f8c248b4e7f9ece5f64d4c13e0f225d4b0 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 13 Jun 2021 23:04:07 +0000 Subject: [PATCH 23/43] [REFACTOR] Simplify initializer validation --- prototorch/core/components.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d2b0f40..a243318 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -29,18 +29,6 @@ def validate_initializer(initializer, instanceof): return True -def validate_components_initializer(initializer): - return validate_initializer(initializer, AbstractComponentsInitializer) - - -def validate_labels_initializer(initializer): - return validate_initializer(initializer, AbstractLabelsInitializer) - - -def validate_reasonings_initializer(initializer): - return validate_initializer(initializer, 
AbstractReasoningsInitializer) - - def gencat(ins, attr, init, *iargs, **ikwargs): """Generate new items and concatenate with existing items.""" new_items = init.generate(*iargs, **ikwargs) @@ -91,7 +79,7 @@ class Components(AbstractComponents): def add_components(self, num_components: int, initializer: AbstractComponentsInitializer): """Generate and add new components.""" - assert validate_components_initializer(initializer) + assert validate_initializer(initializer, AbstractComponentsInitializer) _components, new_components = gencat(self, "_components", initializer, num_components) self._register_components(_components) @@ -162,7 +150,7 @@ class Labels(AbstractLabels): distribution: Union[dict, tuple, list], initializer: AbstractLabelsInitializer = LabelsInitializer()): """Generate and add new labels.""" - assert validate_labels_initializer(initializer) + assert validate_initializer(initializer, AbstractLabelsInitializer) _labels, new_labels = gencat(self, "_labels", initializer, distribution) self._register_labels(_labels) @@ -202,8 +190,10 @@ class LabeledComponents(AbstractComponents): components_initializer, labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): """Generate and add new components and labels.""" - assert validate_components_initializer(components_initializer) - assert validate_labels_initializer(labels_initializer) + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(labels_initializer, + AbstractLabelsInitializer) if isinstance(components_initializer, ClassAwareCompInitializer): cikwargs = dict(distribution=distribution) else: @@ -270,8 +260,10 @@ class ReasoningComponents(AbstractComponents): def add_components(self, distribution, components_initializer, reasonings_initializer: AbstractReasoningsInitializer): # Checks - assert validate_components_initializer(components_initializer) - assert validate_reasonings_initializer(reasonings_initializer) + assert validate_initializer(components_initializer, + AbstractComponentsInitializer) + assert validate_initializer(reasonings_initializer, + AbstractReasoningsInitializer) distribution = parse_distribution(distribution) From d2d6f31e7b6f3b3c406b35fc903f0e5c495a964b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 14:44:36 +0200 Subject: [PATCH 24/43] [REFACTOR] Simplify ReasoningComponents --- prototorch/core/components.py | 62 ++++++++++++++--------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index a243318..f67fc1b 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -17,6 +17,7 @@ from .initializers import ( def validate_initializer(initializer, instanceof): + """Check if the initializer is valid.""" if not isinstance(initializer, instanceof): emsg = f"`initializer` has to be an instance " \ f"of some subtype of {instanceof}. 
" \ @@ -47,6 +48,17 @@ def removeind(ins, attr, indices): return items, mask +def get_cikwargs(init, distribution): + """Return appropriate key-word arguments for a component initializer.""" + if isinstance(init, ClassAwareCompInitializer): + cikwargs = dict(distribution=distribution) + else: + distribution = parse_distribution(distribution) + num_components = sum(distribution.values()) + cikwargs = dict(num_components=num_components) + return cikwargs + + class AbstractComponents(torch.nn.Module): """Abstract class for all components modules.""" @property @@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents): AbstractComponentsInitializer) assert validate_initializer(labels_initializer, AbstractLabelsInitializer) - if isinstance(components_initializer, ClassAwareCompInitializer): - cikwargs = dict(distribution=distribution) - else: - distribution = parse_distribution(distribution) - num_components = sum(distribution.values()) - cikwargs = dict(num_components=num_components) + cikwargs = get_cikwargs(components_initializer, distribution) _components, new_components = gencat(self, "_components", components_initializer, **cikwargs) @@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents): def add_components(self, distribution, components_initializer, reasonings_initializer: AbstractReasoningsInitializer): - # Checks + """Generate and add new components and reasonings.""" assert validate_initializer(components_initializer, AbstractComponentsInitializer) assert validate_initializer(reasonings_initializer, AbstractReasoningsInitializer) - - distribution = parse_distribution(distribution) - - # Generate new components - if isinstance(components_initializer, ClassAwareCompInitializer): - new_components = components_initializer.generate(distribution) - else: - num_components = sum(distribution.values()) - new_components = components_initializer.generate(num_components) - - # Generate new reasonings - new_reasonings = reasonings_initializer.generate(distribution) - - # Register - if hasattr(self, "_components"): - _components = torch.cat([self._components, new_components]) - else: - _components = new_components - if hasattr(self, "_reasonings"): - _reasonings = torch.cat([self._reasonings, new_reasonings]) - else: - _reasonings = new_reasonings + cikwargs = get_cikwargs(components_initializer, distribution) + _components, new_components = gencat(self, "_components", + components_initializer, + **cikwargs) + _reasonings, new_reasonings = gencat(self, "_reasonings", + reasonings_initializer, + distribution) self._register_components(_components) self._register_reasonings(_reasonings) - return new_components, new_reasonings def remove_components(self, indices): - """Remove components and labels at specified indices.""" - mask = torch.ones(self.num_components, dtype=torch.bool) - mask[indices] = False - _components = self._components[mask] - # TODO - # _reasonings = self._reasonings[mask] + """Remove components and reasonings at specified indices.""" + _components, mask = removeind(self, "_components", indices) + _reasonings, mask = removeind(self, "_reasonings", indices) self._register_components(_components) - # self._register_reasonings(_reasonings) + self._register_reasonings(_reasonings) return mask def forward(self): From 668c9a1fb79c26c70492e4884c0d667df8910da4 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 14:45:14 +0200 Subject: [PATCH 25/43] [TEST] Add more tests --- prototorch/core/initializers.py | 10 ++----- tests/test_core.py | 52 
+++++++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index ba48ffd..b361f35 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -81,14 +81,9 @@ class ClassAwareCompInitializer(AbstractComponentsInitializer): def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) - initializers = { - k: self.subinit_type(self.data[self.targets == k]) - for k in distribution.keys() - } components = torch.tensor([]) for k, v in distribution.items(): stratified_data = self.data[self.targets == k] - # skip transform here initializer = self.subinit_type( stratified_data, noise=self.noise, @@ -157,13 +152,14 @@ class UniformCompInitializer(OnesCompInitializer): class RandomNormalCompInitializer(OnesCompInitializer): """Generate components by sampling from a standard normal distribution.""" - def __init__(self, shape, scale=1.0): + def __init__(self, shape, shift=0.0, scale=1.0): super().__init__(shape) + self.shift = shift self.scale = scale def generate(self, num_components: int): ones = super().generate(num_components) - components = self.scale * torch.randn_like(ones) + components = self.scale * (torch.randn_like(ones) + self.shift) return components diff --git a/tests/test_core.py b/tests/test_core.py index 191569e..1a1327e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -73,6 +73,22 @@ def test_fill_value_comp_generate(): assert torch.allclose(components, torch.zeros(3, 2)) +def test_uniform_comp_generate_min_max_bound(): + c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0) + components = c.generate(num_components=1024) + assert components.min() >= -1.0 + assert components.max() <= 1.0 + + +def test_random_comp_generate_mean(): + c = pt.initializers.RandomNormalCompInitializer(2, -1.0) + components = c.generate(num_components=1024) + assert torch.allclose(components.mean(), + torch.tensor(-1.0), + rtol=1e-05, + atol=1e-01) + + def test_comp_generate_0_components(): c = pt.initializers.ZerosCompInitializer(2) _ = c.generate(num_components=0) @@ -294,7 +310,8 @@ class TestCompetitions(unittest.TestCase): def test_wtac(self): d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) labels = torch.tensor([0, 1, 2, 3]) - actual = pt.competitions.wtac(d, labels) + competition_layer = pt.competitions.WTAC() + actual = competition_layer(d, labels) desired = torch.tensor([2, 0]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -304,7 +321,8 @@ class TestCompetitions(unittest.TestCase): def test_wtac_unequal_dist(self): d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]]) labels = torch.tensor([0, 1, 1]) - actual = pt.competitions.wtac(d, labels) + competition_layer = pt.competitions.WTAC() + actual = competition_layer(d, labels) desired = torch.tensor([0, 1]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -314,7 +332,8 @@ class TestCompetitions(unittest.TestCase): def test_wtac_one_hot(self): d = torch.tensor([[1.99, 3.01], [3.0, 2.01]]) labels = torch.tensor([[0, 1], [1, 0]]) - actual = pt.competitions.wtac(d, labels) + competition_layer = pt.competitions.WTAC() + actual = competition_layer(d, labels) desired = torch.tensor([[0, 1], [1, 0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -324,7 +343,8 @@ class TestCompetitions(unittest.TestCase): def test_knnc_k1(self): d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) 
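# d holds the distances of 2 samples to 4 labeled prototypes; the nearest ones sit at indices 2 and 0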
labels = torch.tensor([0, 1, 2, 3]) - actual = pt.competitions.knnc(d, labels, k=1) + competition_layer = pt.competitions.KNNC(k=1) + actual = competition_layer(d, labels) desired = torch.tensor([2, 0]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -343,7 +363,8 @@ class TestPooling(unittest.TestCase): def test_stratified_min(self): d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) labels = torch.tensor([0, 0, 1, 2]) - actual = pt.pooling.stratified_min_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedMinPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -354,7 +375,8 @@ class TestPooling(unittest.TestCase): d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) labels = torch.tensor([0, 0, 1, 2]) labels = torch.eye(3)[labels] - actual = pt.pooling.stratified_min_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedMinPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -364,7 +386,8 @@ class TestPooling(unittest.TestCase): def test_stratified_min_trivial(self): d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) labels = torch.tensor([0, 1, 2]) - actual = pt.pooling.stratified_min_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedMinPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -374,7 +397,8 @@ class TestPooling(unittest.TestCase): def test_stratified_max(self): d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pt.pooling.stratified_max_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedMaxPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -385,7 +409,8 @@ class TestPooling(unittest.TestCase): d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) labels = torch.tensor([0, 0, 2, 1, 0]) labels = torch.nn.functional.one_hot(labels, num_classes=3) - actual = pt.pooling.stratified_max_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedMaxPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -395,7 +420,8 @@ class TestPooling(unittest.TestCase): def test_stratified_sum(self): d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) labels = torch.LongTensor([0, 0, 1, 2]) - actual = pt.pooling.stratified_sum_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedSumPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -406,7 +432,8 @@ class TestPooling(unittest.TestCase): d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) labels = torch.tensor([0, 0, 1, 2]) labels = torch.eye(3)[labels] - actual = pt.pooling.stratified_sum_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedSumPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, @@ -416,7 +443,8 @@ class TestPooling(unittest.TestCase): def 
test_stratified_prod(self): d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) labels = torch.tensor([0, 0, 3, 2, 0]) - actual = pt.pooling.stratified_prod_pooling(d, labels) + pooling_layer = pt.pooling.StratifiedProdPooling() + actual = pooling_layer(d, labels) desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) mismatch = np.testing.assert_array_almost_equal(actual, desired, From 083cc929be7e1c7120bc12fa499e1cae2d9d4852 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 17:19:45 +0200 Subject: [PATCH 26/43] [REFACTOR] Add reasonings initializers --- prototorch/core/initializers.py | 98 ++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 19 deletions(-) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index b361f35..da0341d 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC): class LabelsInitializer(AbstractLabelsInitializer): - """Generate labels with `self.distribution`.""" - def __init__(self, override_labels: list = []): - self.override_labels = override_labels - + """Generate labels from `distribution`.""" def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) labels = [] @@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer): return labels -# Reasonings -class AbstractReasoningsInitializer(ABC): - """Abstract class for all reasonings initializers.""" - @abstractmethod - def generate(self, distribution: Union[dict, list, tuple]): - ... - - -class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): - """Generate labels with `self.distribution`.""" +class OneHotLabelsInitializer(LabelsInitializer): + """Generate one-hot-encoded labels from `distribution`.""" def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) num_classes = len(distribution.keys()) + # this breaks if class labels are not [0,...,nclasses] + labels = torch.eye(num_classes)[super().generate(distribution)] + return labels + + +# Reasonings +class AbstractReasoningsInitializer(ABC): + """Abstract class for all reasonings initializers.""" + def __init__(self, components_first=True): + self.components_first = components_first + + def compute_shape(self, distribution): + distribution = parse_distribution(distribution) num_components = sum(distribution.values()) - assert num_classes == num_components - reasonings = torch.stack( - [torch.eye(num_classes), - torch.zeros(num_classes, num_classes)], - dim=0) + num_classes = len(distribution.keys()) + return (num_components, num_classes, 2) + + def generate_end_hook(self, reasonings): + if not self.components_first: + reasonings = reasonings.permute(2, 1, 0) + return reasonings + + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + return generate_end_hook(...) 
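# Illustrative sketch, not part of the patch: the two hooks above fix the reasonings layout. `compute_shape` yields (num_components, num_classes, 2) and `generate_end_hook` flips the axes when `components_first=False`. Assuming the concrete initializers below are exposed under the `pt.initializers` namespace used in the tests:
#
#     import prototorch as pt
#     r = pt.initializers.RandomReasoningsInitializer()
#     r.generate([1, 2]).shape  # torch.Size([3, 2, 2])
#     r = pt.initializers.RandomReasoningsInitializer(components_first=False)
#     r.generate([1, 2]).shape  # torch.Size([2, 2, 3])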
+ + +class ZerosReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are all initialized with zeros.""" + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.zeros(*shape) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class OnesReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are all initialized with ones.""" + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.ones(*shape) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class RandomReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are randomly initialized.""" + def __init__(self, minimum=0.4, maximum=0.6, **kwargs): + super().__init__(**kwargs) + self.minimum = minimum + self.maximum = maximum + + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): + """Each component reasons positively for exactly one class.""" + def generate(self, distribution: Union[dict, list, tuple]): + num_components, num_classes, _ = self.compute_shape(distribution) + A = OneHotLabelsInitializer().generate(distribution) + B = torch.zeros(num_components, num_classes) + reasonings = torch.stack([A, B]).permute(2, 1, 0) + reasonings = self.generate_end_hook(reasonings) return reasonings @@ -251,4 +302,13 @@ SCI = SelectionCompInitializer MCI = MeanCompInitializer SSCI = StratifiedSelectionCompInitializer SMCI = StratifiedMeanCompInitializer + +# Aliases - Labels +LI = LabelsInitializer +OHLI = OneHotLabelsInitializer + +# Aliases - Reasonings +ZRI = ZerosReasoningsInitializer +ORI = OnesReasoningsInitializer +RRI = RandomReasoningsInitializer PPRI = PurePositiveReasoningsInitializer From 92414755709d0d61d262e8fd45f814a03e82f457 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 17:20:22 +0200 Subject: [PATCH 27/43] [REFACTOR] Refactor `parse_distribution` --- prototorch/utils/utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py index b2058cd..79c528b 100644 --- a/prototorch/utils/utils.py +++ b/prototorch/utils/utils.py @@ -23,10 +23,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100): return mesh, xx, yy -def parse_distribution( - user_distribution: Union[dict[int, int], dict[str, str], list[int], - tuple[int]] -) -> dict[int, int]: +def distribution_from_list(list_dist: list[int], clabels: list[int] = []): + clabels = clabels or list(range(len(list_dist))) + distribution = dict(zip(clabels, list_dist)) + return distribution + + +def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str], + list[int], tuple[int]], + clabels: list[int] = []) -> dict[int, int]: """Parse user-provided distribution. Return a dictionary with integer keys that represent the class labels and @@ -47,25 +52,20 @@ def parse_distribution( as one might expect. 
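The optional `clabels` argument overrides the integer keys that are used as the class labels.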
""" - def from_list(list_dist): - clabels = list(range(len(list_dist))) - distribution = dict(zip(clabels, list_dist)) - return distribution - if isinstance(user_distribution, dict): if "num_classes" in user_distribution.keys(): num_classes = int(user_distribution["num_classes"]) per_class = int(user_distribution["per_class"]) - return from_list([per_class] * num_classes) + return distribution_from_list([per_class] * num_classes, clabels) else: return user_distribution elif isinstance(user_distribution, tuple): assert len(user_distribution) == 2 num_classes, per_class = user_distribution num_classes, per_class = int(num_classes), int(per_class) - return from_list([per_class] * num_classes) + return distribution_from_list([per_class] * num_classes, clabels) elif isinstance(user_distribution, list): - return from_list(user_distribution) + return distribution_from_list(user_distribution, clabels) else: msg = f"`distribution` not understood." \ f"You have provided: {user_distribution}." From 549e6a10c177479544a4e83e0d8bea1f1be2c20f Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 17:20:57 +0200 Subject: [PATCH 28/43] [TEST] Add tests for reasonings initializers --- tests/test_core.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 1a1327e..757e678 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -41,6 +41,13 @@ def test_parse_distribution_list(): assert distribution == {0: 1, 1: 1, 2: 0, 3: 2} +def test_parse_distribution_custom_labels(): + distribution = [1, 1, 0, 2] + clabels = [1, 2, 5, 3] + distribution = parse_distribution(distribution, clabels) + assert distribution == {1: 1, 2: 1, 5: 0, 3: 2} + + # Components initializers def test_shape_aware_raises_error(): with pytest.raises(TypeError): @@ -147,6 +154,50 @@ def test_labels_init_from_tuple_illegal(): _ = l.generate(distribution=(1, 1, 1)) +# Reasonings initializers +def test_random_reasonings_init(): + r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) + reasonings = r.generate(distribution=[0, 1]) + assert torch.numel(reasonings) == 1 * 2 * 2 + assert reasonings.min() >= 0.2 + assert reasonings.max() <= 0.8 + + +def test_zeros_reasonings_init(): + r = pt.initializers.ZerosReasoningsInitializer() + reasonings = r.generate(distribution=[0, 1]) + assert torch.allclose(reasonings, torch.zeros(1, 2, 2)) + + +def test_ones_reasonings_init(): + r = pt.initializers.ZerosReasoningsInitializer() + reasonings = r.generate(distribution=[1, 2, 3]) + assert torch.allclose(reasonings, torch.zeros(6, 3, 2)) + + +def test_random_reasonings_init_channels_not_first(): + r = pt.initializers.RandomReasoningsInitializer(components_first=False) + reasonings = r.generate(distribution=[1, 2]) + assert reasonings.shape[0] == 2 + assert reasonings.shape[-1] == 3 + + +def test_pure_positive_reasonings_init_one_per_class(): + r = pt.initializers.PurePositiveReasoningsInitializer( + components_first=False) + reasonings = r.generate(distribution=(4, 1)) + assert torch.allclose(reasonings[0], torch.eye(4)) + + +def test_pure_positive_reasonings_init_unrepresented_class(): + r = pt.initializers.PurePositiveReasoningsInitializer( + components_first=False) + reasonings = r.generate(distribution=[1, 0, 1]) + assert reasonings.shape[0] == 2 + assert reasonings.shape[1] == 2 + assert reasonings.shape[2] == 3 + + # Components def test_components_no_initializer(): with pytest.raises(TypeError): From 
fc9edeaa970aa76d79ce2031d8953aa2bcbfc675 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 19:53:02 +0200 Subject: [PATCH 29/43] [FEATURE] Add more initializers --- prototorch/core/components.py | 8 +- prototorch/core/initializers.py | 280 +++++++++++++++++++++----------- prototorch/utils/utils.py | 12 +- 3 files changed, 198 insertions(+), 102 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index f67fc1b..d0155a7 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -8,10 +8,10 @@ from torch.nn.parameter import Parameter from ..utils import parse_distribution from .initializers import ( + AbstractClassAwareCompInitializer, AbstractComponentsInitializer, AbstractLabelsInitializer, AbstractReasoningsInitializer, - ClassAwareCompInitializer, LabelsInitializer, ) @@ -50,7 +50,7 @@ def removeind(ins, attr, indices): def get_cikwargs(init, distribution): """Return appropriate key-word arguments for a component initializer.""" - if isinstance(init, ClassAwareCompInitializer): + if isinstance(init, AbstractClassAwareCompInitializer): cikwargs = dict(distribution=distribution) else: distribution = parse_distribution(distribution) @@ -69,7 +69,7 @@ class AbstractComponents(torch.nn.Module): @property def components(self): """Detached Tensor containing the components.""" - return self._components.detach() + return self._components.detach().cpu() def _register_components(self, components): self.register_parameter("_components", Parameter(components)) @@ -259,7 +259,7 @@ class ReasoningComponents(AbstractComponents): Dimension NxCx2 """ - return self._reasonings.detach() + return self._reasonings.detach().cpu() def _register_reasonings(self, reasonings): self.register_parameter("_reasonings", Parameter(reasonings)) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index da0341d..8e65823 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -1,5 +1,6 @@ """ProtoTorch code initializers""" +import warnings from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Union @@ -15,6 +16,24 @@ class AbstractComponentsInitializer(ABC): ... +class LiteralCompInitializer(AbstractComponentsInitializer): + """'Generate' the provided components. + + Use this to 'generate' pre-initialized components elsewhere. + + """ + def __init__(self, components): + self.components = components + + def generate(self, num_components: int = 0): + """Ignore `num_components` and simply return `self.components`.""" + if not isinstance(self.components, torch.Tensor): + wmsg = f"Converting components to {torch.Tensor}..." + warnings.warn(wmsg) + self.components = torch.Tensor(self.components) + return self.components + + class ShapeAwareCompInitializer(AbstractComponentsInitializer): """Abstract class for all dimension-aware components initializers.""" def __init__(self, shape: Union[Iterable, int]): @@ -28,88 +47,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer): ... -class DataAwareCompInitializer(AbstractComponentsInitializer): - """Abstract class for all data-aware components initializers. - - Components generated by data-aware components initializers inherit the shape - of the provided data. 
- - """ - def __init__(self, - data, - noise: float = 0.0, - transform: callable = torch.nn.Identity()): - self.data = data - self.noise = noise - self.transform = transform - - def generate_end_hook(self, samples): - drift = torch.rand_like(samples) * self.noise - components = self.transform(samples + drift) - return components - - @abstractmethod - def generate(self, num_components: int): - ... - return self.generate_end_hook(...) - - def __del__(self): - del self.data - - -class ClassAwareCompInitializer(AbstractComponentsInitializer): - """Abstract class for all class-aware components initializers. - - Components generated by class-aware components initializers inherit the shape - of the provided data. - - """ - def __init__(self, - data, - noise: float = 0.0, - transform: callable = torch.nn.Identity()): - self.data, self.targets = parse_data_arg(data) - self.noise = noise - self.transform = transform - self.clabels = torch.unique(self.targets).int().tolist() - self.num_classes = len(self.clabels) - - @property - @abstractmethod - def subinit_type(self) -> DataAwareCompInitializer: - ... - - def generate(self, distribution: Union[dict, list, tuple]): - distribution = parse_distribution(distribution) - components = torch.tensor([]) - for k, v in distribution.items(): - stratified_data = self.data[self.targets == k] - initializer = self.subinit_type( - stratified_data, - noise=self.noise, - transform=self.transform, - ) - samples = initializer.generate(num_components=v) - components = torch.cat([components, samples]) - return components - - def __del__(self): - del self.data - del self.targets - - -class LiteralCompInitializer(DataAwareCompInitializer): - """'Generate' the provided components. - - Use this to 'generate' pre-initialized components from elsewhere. - - """ - def generate(self, num_components: int): - """Ignore `num_components` and simply return transformed `self.data`.""" - components = self.transform(self.data) - return components - - class ZerosCompInitializer(ShapeAwareCompInitializer): """Generate zeros corresponding to the components shape.""" def generate(self, num_components: int): @@ -163,7 +100,46 @@ class RandomNormalCompInitializer(OnesCompInitializer): return components -class SelectionCompInitializer(DataAwareCompInitializer): +class AbstractDataAwareCompInitializer(AbstractComponentsInitializer): + """Abstract class for all data-aware components initializers. + + Components generated by data-aware components initializers inherit the shape + of the provided data. + + `data` has to be a torch tensor. + + """ + def __init__(self, + data: torch.TensorType, + noise: float = 0.0, + transform: callable = torch.nn.Identity()): + self.data = data + self.noise = noise + self.transform = transform + + def generate_end_hook(self, samples): + drift = torch.rand_like(samples) * self.noise + components = self.transform(samples + drift) + return components + + @abstractmethod + def generate(self, num_components: int): + ... + return self.generate_end_hook(...) 
+ + def __del__(self): + del self.data + + +class DataAwareCompInitializer(AbstractDataAwareCompInitializer): + """'Generate' the components from the provided data.""" + def generate(self, num_components: int = 0): + """Ignore `num_components` and simply return transformed `self.data`.""" + components = self.generate_end_hook(self.data) + return components + + +class SelectionCompInitializer(AbstractDataAwareCompInitializer): """Generate components by uniformly sampling from the provided data.""" def generate(self, num_components: int): indices = torch.LongTensor(num_components).random_(0, len(self.data)) @@ -172,7 +148,7 @@ class SelectionCompInitializer(DataAwareCompInitializer): return components -class MeanCompInitializer(DataAwareCompInitializer): +class MeanCompInitializer(AbstractDataAwareCompInitializer): """Generate components by computing the mean of the provided data.""" def generate(self, num_components: int): mean = torch.mean(self.data, dim=0) @@ -182,14 +158,74 @@ class MeanCompInitializer(DataAwareCompInitializer): return components -class StratifiedSelectionCompInitializer(ClassAwareCompInitializer): +class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer): + """Abstract class for all class-aware components initializers. + + Components generated by class-aware components initializers inherit the shape + of the provided data. + + `data` could be a torch Dataset or DataLoader or a list/tuple of data and + target tensors. + + """ + def __init__(self, + data, + noise: float = 0.0, + transform: callable = torch.nn.Identity()): + self.data, self.targets = parse_data_arg(data) + self.noise = noise + self.transform = transform + self.clabels = torch.unique(self.targets).int().tolist() + self.num_classes = len(self.clabels) + + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple] = []): + ... + return self.generate_end_hook(...) + + def __del__(self): + del self.data + del self.targets + + +class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): + """'Generate' components from provided data and requested distribution.""" + def generate(self, distribution: Union[dict, list, tuple] = []): + """Ignore `distribution` and simply return transformed `self.data`.""" + components = self.generate_end_hook(self.data) + return components + + +class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): + """Abstract class for all stratified components initializers.""" + @property + @abstractmethod + def subinit_type(self) -> AbstractDataAwareCompInitializer: + ... 
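# The `generate` method below builds one `subinit_type` initializer per class from the stratified slice of the data and concatenates the per-class samples.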
+ + def generate(self, distribution: Union[dict, list, tuple]): + distribution = parse_distribution(distribution) + components = torch.tensor([]) + for k, v in distribution.items(): + stratified_data = self.data[self.targets == k] + initializer = self.subinit_type( + stratified_data, + noise=self.noise, + transform=self.transform, + ) + samples = initializer.generate(num_components=v) + components = torch.cat([components, samples]) + return components + + +class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): """Generate components using stratified sampling from the provided data.""" @property def subinit_type(self): return SelectionCompInitializer -class StratifiedMeanCompInitializer(ClassAwareCompInitializer): +class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): """Generate components at stratified means of the provided data.""" @property def subinit_type(self): @@ -204,6 +240,38 @@ class AbstractLabelsInitializer(ABC): ... +class LiteralLabelsInitializer(AbstractLabelsInitializer): + """'Generate' the provided labels. + + Use this to 'generate' pre-initialized labels elsewhere. + + """ + def __init__(self, labels): + self.labels = labels + + def generate(self, distribution: Union[dict, list, tuple] = []): + """Ignore `distribution` and simply return `self.labels`. + + Convert to long tensor, if necessary. + """ + labels = self.labels + if not isinstance(labels, torch.LongTensor): + wmsg = f"Converting labels to {torch.LongTensor}..." + warnings.warn(wmsg) + labels = torch.LongTensor(labels) + return labels + + +class DataAwareLabelsInitializer(AbstractLabelsInitializer): + """'Generate' the labels from a torch Dataset.""" + def __init__(self, data): + self.data, self.targets = parse_data_arg(data) + + def generate(self, distribution: Union[dict, list, tuple] = []): + """Ignore `num_components` and simply return `self.targets`.""" + return self.targets + + class LabelsInitializer(AbstractLabelsInitializer): """Generate labels from `distribution`.""" def generate(self, distribution: Union[dict, list, tuple]): @@ -248,6 +316,27 @@ class AbstractReasoningsInitializer(ABC): return generate_end_hook(...) +class LiteralReasoningsInitializer(AbstractReasoningsInitializer): + """'Generate' the provided reasonings. + + Use this to 'generate' pre-initialized reasonings elsewhere. + + """ + def __init__(self, reasonings, **kwargs): + super().__init__(**kwargs) + self.reasonings = reasonings + + def generate(self, distribution: Union[dict, list, tuple] = []): + """Ignore `distributuion` and simply return self.reasonings.""" + reasonings = self.reasonings + if not isinstance(reasonings, torch.Tensor): + wmsg = f"Converting reasonings to {torch.Tensor}..." 
+ warnings.warn(wmsg) + reasonings = torch.Tensor(reasonings) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + class ZerosReasoningsInitializer(AbstractReasoningsInitializer): """Reasonings are all initialized with zeros.""" def generate(self, distribution: Union[dict, list, tuple]): @@ -292,23 +381,28 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): # Aliases - Components -ZCI = ZerosCompInitializer -OCI = OnesCompInitializer +CACI = ClassAwareCompInitializer +DACI = DataAwareCompInitializer FVCI = FillValueCompInitializer LCI = LiteralCompInitializer -UCI = UniformCompInitializer +MCI = MeanCompInitializer +OCI = OnesCompInitializer RNCI = RandomNormalCompInitializer SCI = SelectionCompInitializer -MCI = MeanCompInitializer -SSCI = StratifiedSelectionCompInitializer SMCI = StratifiedMeanCompInitializer +SSCI = StratifiedSelectionCompInitializer +UCI = UniformCompInitializer +ZCI = ZerosCompInitializer # Aliases - Labels +DLI = DataAwareLabelsInitializer LI = LabelsInitializer +LLI = LiteralLabelsInitializer OHLI = OneHotLabelsInitializer # Aliases - Reasonings -ZRI = ZerosReasoningsInitializer +LRI = LiteralReasoningsInitializer ORI = OnesReasoningsInitializer -RRI = RandomReasoningsInitializer PPRI = PurePositiveReasoningsInitializer +RRI = RandomReasoningsInitializer +ZRI = ZerosReasoningsInitializer diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py index 79c528b..d87b26f 100644 --- a/prototorch/utils/utils.py +++ b/prototorch/utils/utils.py @@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str], elif isinstance(user_distribution, list): return distribution_from_list(user_distribution, clabels) else: - msg = f"`distribution` not understood." \ + msg = f"`distribution` was not understood." \ f"You have provided: {user_distribution}." raise ValueError(msg) def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]): + """Return data and target as torch tensors.""" if isinstance(data_arg, Dataset): ds_size = len(data_arg) - data_arg = DataLoader(data_arg, batch_size=ds_size) + loader = DataLoader(data_arg, batch_size=ds_size) + data, targets = next(iter(loader)) - if isinstance(data_arg, DataLoader): + elif isinstance(data_arg, DataLoader): data = torch.tensor([]) targets = torch.tensor([]) for x, y in data_arg: @@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]): assert len(data_arg) == 2 data, targets = data_arg if not isinstance(data, torch.Tensor): - wmsg = f"Converting data to {torch.Tensor}." + wmsg = f"Converting data to {torch.Tensor}..." warnings.warn(wmsg) data = torch.Tensor(data) if not isinstance(targets, torch.LongTensor): - wmsg = f"Converting targets to {torch.LongTensor}." + wmsg = f"Converting targets to {torch.LongTensor}..." 
warnings.warn(wmsg) targets = torch.LongTensor(targets) return data, targets From d45e71256c77ad77eeca4ce287684239efc24662 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 19:53:44 +0200 Subject: [PATCH 30/43] [TEST] Test literal initializers --- tests/test_core.py | 49 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 757e678..4862758 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -49,18 +49,41 @@ def test_parse_distribution_custom_labels(): # Components initializers +def test_literal_comp_generate(): + protos = torch.rand(4, 3, 5, 5) + c = pt.initializers.LiteralCompInitializer(protos) + components = c.generate() + assert torch.allclose(components, protos) + + +def test_literal_comp_generate_from_list(): + protos = [[0, 1], [2, 3], [4, 5]] + c = pt.initializers.LiteralCompInitializer(protos) + with pytest.warns(UserWarning): + components = c.generate() + assert torch.allclose(components, torch.Tensor(protos)) + + def test_shape_aware_raises_error(): with pytest.raises(TypeError): _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, )) -def test_literal_comp_generate(): +def test_data_aware_comp_generate(): protos = torch.rand(4, 3, 5, 5) - c = pt.initializers.LiteralCompInitializer(protos) + c = pt.initializers.DataAwareCompInitializer(protos) components = c.generate(num_components="IgnoreMe!") assert torch.allclose(components, protos) +def test_class_aware_comp_generate(): + protos = torch.rand(4, 2, 3, 5, 5) + plabels = torch.tensor([0, 0, 1, 1]).long() + c = pt.initializers.ClassAwareCompInitializer([protos, plabels]) + components = c.generate(distribution=[]) + assert torch.allclose(components, protos) + + def test_zeros_comp_generate(): shape = (3, 5, 5) c = pt.initializers.ZerosCompInitializer(shape) @@ -136,6 +159,13 @@ def test_stratified_selection_comp_generate(): # Labels initializers +def test_literal_labels_init(): + l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) + with pytest.warns(UserWarning): + labels = l.generate() + assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) + + def test_labels_init_from_list(): l = pt.initializers.LabelsInitializer() components = l.generate(distribution=[1, 1, 1]) @@ -154,7 +184,22 @@ def test_labels_init_from_tuple_illegal(): _ = l.generate(distribution=(1, 1, 1)) +def test_data_aware_labels_init(): + data, targets = [0, 1, 2, 3], [0, 0, 1, 1] + ds = pt.datasets.NumpyDataset(data, targets) + l = pt.initializers.DataAwareLabelsInitializer(ds) + labels = l.generate() + assert torch.allclose(labels, torch.LongTensor(targets)) + + # Reasonings initializers +def test_literal_reasonings_init(): + r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2]) + with pytest.warns(UserWarning): + reasonings = r.generate() + assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2])) + + def test_random_reasonings_init(): r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8) reasonings = r.generate(distribution=[0, 1]) From 1f458ac0cc1c3a0da03c96dd8afad1341a226a5b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 14 Jun 2021 21:08:48 +0200 Subject: [PATCH 31/43] [FEATURE] Add distribution property to LabeledComponents --- prototorch/core/components.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d0155a7..f1694ab 100644 --- a/prototorch/core/components.py +++ 
b/prototorch/core/components.py @@ -116,7 +116,7 @@ class AbstractLabels(torch.nn.Module): @property def num_labels(self): - return len(self.labels) + return len(self._labels) @property def unique_labels(self): @@ -193,6 +193,13 @@ class LabeledComponents(AbstractComponents): """Tensor containing the component labels.""" return self._labels + @property + def distribution(self): + unique, counts = torch.unique(self._labels, + sorted=True, + return_counts=True) + return dict(zip(unique.tolist(), counts.tolist())) + def _register_labels(self, labels): self.register_buffer("_labels", labels) From 0f450ed8a0bd673bdacb0a87e3e4e8a471ebfae9 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 15 Jun 2021 00:14:34 +0200 Subject: [PATCH 32/43] [BUGFIX] Remove dangerous mutable default arguments See https://stackoverflow.com/questions/1132941/least-astonishment-and-the-mutable-default-argument for more information. --- prototorch/core/initializers.py | 10 +++++----- tests/test_core.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index 8e65823..6a7067b 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -179,7 +179,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer): self.num_classes = len(self.clabels) @abstractmethod - def generate(self, distribution: Union[dict, list, tuple] = []): + def generate(self, distribution: Union[dict, list, tuple]): ... return self.generate_end_hook(...) @@ -190,7 +190,7 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer): class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): """'Generate' components from provided data and requested distribution.""" - def generate(self, distribution: Union[dict, list, tuple] = []): + def generate(self, distribution: Union[dict, list, tuple]): """Ignore `distribution` and simply return transformed `self.data`.""" components = self.generate_end_hook(self.data) return components @@ -249,7 +249,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer): def __init__(self, labels): self.labels = labels - def generate(self, distribution: Union[dict, list, tuple] = []): + def generate(self, distribution: Union[dict, list, tuple]): """Ignore `distribution` and simply return `self.labels`. Convert to long tensor, if necessary. 
@@ -267,7 +267,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer): def __init__(self, data): self.data, self.targets = parse_data_arg(data) - def generate(self, distribution: Union[dict, list, tuple] = []): + def generate(self, distribution: Union[dict, list, tuple]): """Ignore `num_components` and simply return `self.targets`.""" return self.targets @@ -326,7 +326,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer): super().__init__(**kwargs) self.reasonings = reasonings - def generate(self, distribution: Union[dict, list, tuple] = []): + def generate(self, distribution: Union[dict, list, tuple]): """Ignore `distributuion` and simply return self.reasonings.""" reasonings = self.reasonings if not isinstance(reasonings, torch.Tensor): diff --git a/tests/test_core.py b/tests/test_core.py index 4862758..6fad03f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -52,7 +52,7 @@ def test_parse_distribution_custom_labels(): def test_literal_comp_generate(): protos = torch.rand(4, 3, 5, 5) c = pt.initializers.LiteralCompInitializer(protos) - components = c.generate() + components = c.generate([]) assert torch.allclose(components, protos) @@ -60,7 +60,7 @@ def test_literal_comp_generate_from_list(): protos = [[0, 1], [2, 3], [4, 5]] c = pt.initializers.LiteralCompInitializer(protos) with pytest.warns(UserWarning): - components = c.generate() + components = c.generate([]) assert torch.allclose(components, torch.Tensor(protos)) @@ -162,7 +162,7 @@ def test_stratified_selection_comp_generate(): def test_literal_labels_init(): l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2]) with pytest.warns(UserWarning): - labels = l.generate() + labels = l.generate([]) assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2])) @@ -188,7 +188,7 @@ def test_data_aware_labels_init(): data, targets = [0, 1, 2, 3], [0, 0, 1, 1] ds = pt.datasets.NumpyDataset(data, targets) l = pt.initializers.DataAwareLabelsInitializer(ds) - labels = l.generate() + labels = l.generate([]) assert torch.allclose(labels, torch.LongTensor(targets)) @@ -196,7 +196,7 @@ def test_data_aware_labels_init(): def test_literal_reasonings_init(): r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2]) with pytest.warns(UserWarning): - reasonings = r.generate() + reasonings = r.generate([]) assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2])) From 6e8a52e37137878c173e63536b0c84093b5edccd Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 15 Jun 2021 15:41:28 +0200 Subject: [PATCH 33/43] [FEATURE] Add standalone reasonings and CBC competition --- prototorch/__init__.py | 2 + prototorch/core/competitions.py | 28 +++++++++++ prototorch/core/components.py | 88 ++++++++++++++++++++++++++++----- 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index d549de2..d0ce2d6 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -14,6 +14,7 @@ from .core import ( components, distances, initializers, + similarities, losses, pooling, ) @@ -31,6 +32,7 @@ __all_core__ = [ "losses", "nn", "pooling", + "similarities", "utils", ] diff --git a/prototorch/core/competitions.py b/prototorch/core/competitions.py index 2e354b6..2a54e10 100644 --- a/prototorch/core/competitions.py +++ b/prototorch/core/competitions.py @@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor, return winning_labels +def cbcc(detections: torch.Tensor, reasonings: torch.Tensor): + """Classification-By-Components Competition. 
+ + Returns probability distributions over the classes. + + `detections` must be of shape [batch_size, num_components]. + `reasonings` must be of shape [num_components, num_classes, 2]. + + """ + A, B = reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - A) * B + numerator = (detections @ (pk - nk).T) + nk.sum(1) + probs = numerator / (pk + nk).sum(1) + # probs = probs.squeeze(0) + return probs + + class WTAC(torch.nn.Module): """Winner-Takes-All-Competition Layer. @@ -63,3 +81,13 @@ class KNNC(torch.nn.Module): def extra_repr(self): return f"k: {self.k}" + + +class CBCC(torch.nn.Module): + """Classification-By-Components Competition. + + Thin wrapper over the `cbcc` function. + + """ + def forward(self, detections, reasonings): + return cbcc(detections, reasonings) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index f1694ab..d497cdf 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -13,6 +13,7 @@ from .initializers import ( AbstractLabelsInitializer, AbstractReasoningsInitializer, LabelsInitializer, + RandomReasoningsInitializer, ) @@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module): """Abstract class for all labels modules.""" @property def labels(self): - return self._labels + return self._labels.cpu() @property def num_labels(self): @@ -174,6 +175,10 @@ class Labels(AbstractLabels): self._register_labels(_labels) return mask + def forward(self): + """Simply return the labels.""" + return self._labels + class LabeledComponents(AbstractComponents): """A set of adaptable components and corresponding unadaptable labels.""" @@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents): self.add_components(distribution, components_initializer, labels_initializer) - @property - def labels(self): - """Tensor containing the component labels.""" - return self._labels - @property def distribution(self): unique, counts = torch.unique(self._labels, @@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents): return_counts=True) return dict(zip(unique.tolist(), counts.tolist())) + @property + def num_classes(self): + return len(self.distribution.keys()) + + @property + def labels(self): + """Tensor containing the component labels.""" + return self._labels.cpu() + def _register_labels(self, labels): self.register_buffer("_labels", labels) @@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents): return self._components, self._labels +class Reasonings(torch.nn.Module): + """A set of standalone reasoning matrices. + + The `reasonings` tensor is of shape [num_components, num_classes, 2]. 
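Along the last axis, index 0 holds the positive and index 1 the negative reasonings.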
+ + """ + def __init__(self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer(), + **kwargs): + super().__init__(**kwargs) + + @property + def num_classes(self): + return self._reasonings.shape[1] + + # @property + # def reasonings(self): + # """Tensor containing the reasoning matrices.""" + # return self._reasonings.detach().cpu() + + @property + def reasonings(self): + with torch.no_grad(): + A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - pk) * B + ik = 1 - pk - nk + img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2) + return img.unsqueeze(1).cpu() + + def _register_reasonings(self, reasonings): + self.register_buffer("_reasonings", reasonings) + + def add_reasonings( + self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer()): + """Generate and add new reasonings.""" + assert validate_initializer(initializer, AbstractReasoningsInitializer) + _reasonings, new_reasonings = gencat(self, "_reasonings", initializer, + distribution) + self._register_reasonings(_reasonings) + return new_reasonings + + def remove_reasonings(self, indices): + """Remove reasonings at specified indices.""" + _reasonings, mask = removeind(self, "_reasonings", indices) + self._register_reasonings(_reasonings) + return mask + + def forward(self): + """Simply return the reasonings.""" + return self._reasonings + + class ReasoningComponents(AbstractComponents): """A set of components and a corresponding adapatable reasoning matrices. @@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents): reasonings_initializer) @property - def reasonings(self): - """Returns Reasoning Matrix. - - Dimension NxCx2 - - """ - return self._reasonings.detach().cpu() + def num_classes(self): + return self._reasonings.shape[1] def _register_reasonings(self, reasonings): self.register_parameter("_reasonings", Parameter(reasonings)) From 42eb53d73a7fe5fcf2c968dc83beca10a3f4b9bc Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 15 Jun 2021 15:57:59 +0200 Subject: [PATCH 34/43] [FEATURE] Add `euclidean_similarity` and `margin_loss` --- prototorch/core/losses.py | 20 ++++++++++++++++++++ prototorch/core/similarities.py | 12 ++++++++++++ 2 files changed, 32 insertions(+) diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index ab3705f..1a32103 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -98,6 +98,13 @@ def rslvq_loss(probabilities, targets, prototype_labels): return -1.0 * log_likelihood +def margin_loss(y_pred, y_true, margin=0.3): + """Compute the margin loss.""" + dp = torch.sum(y_true * y_pred, dim=-1) + dm = torch.max(y_pred - y_true, dim=-1).values + return torch.nn.functional.relu(dm - dp + margin) + + class GLVQLoss(torch.nn.Module): def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): super().__init__(**kwargs) @@ -112,6 +119,19 @@ class GLVQLoss(torch.nn.Module): return torch.sum(batch_loss, dim=0) +class MarginLoss(torch.nn.modules.loss._Loss): + def __init__(self, + margin=0.3, + size_average=None, + reduce=None, + reduction="mean"): + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, y_pred, y_true): + return margin_loss(y_pred, y_true, self.margin) + + class NeuralGasEnergy(torch.nn.Module): def __init__(self, lm, **kwargs): super().__init__(**kwargs) diff --git a/prototorch/core/similarities.py b/prototorch/core/similarities.py index 
6125f8e..9929610 100644 --- a/prototorch/core/similarities.py +++ b/prototorch/core/similarities.py @@ -2,6 +2,18 @@ import torch +from .distances import euclidean_distance + + +def gaussian(x, variance=1.0): + return torch.exp(-(x * x) / (2 * variance)) + + +def euclidean_similarity(x, y, variance=1.0): + distances = euclidean_distance(x, y) + similarities = gaussian(distances, variance) + return similarities + def cosine_similarity(x, y): """Compute the cosine similarity between :math:`x` and :math:`y`. From 3a0e4a081ed91e5eaf979e4446e396dfe0f30183 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 12:34:15 +0200 Subject: [PATCH 35/43] Improve error message --- prototorch/nn/activations.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/prototorch/nn/activations.py b/prototorch/nn/activations.py index 7931e14..ab70762 100644 --- a/prototorch/nn/activations.py +++ b/prototorch/nn/activations.py @@ -57,6 +57,10 @@ def get_activation(funcname): """Deserialize the activation function.""" if callable(funcname): return funcname - if funcname in ACTIVATIONS: + elif funcname in ACTIVATIONS: return ACTIVATIONS.get(funcname) - raise NameError(f"Activation {funcname} was not found.") + else: + emsg = f"Unable to find matching function for `{funcname}` " \ + f"in `prototorch.nn.activations`. " + helpmsg = f"Possible values are {list(ACTIVATIONS.keys())}." + raise NameError(emsg + helpmsg) From 70b4fa07e66acd97b0db09143423fdfb77965413 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 12:34:33 +0200 Subject: [PATCH 36/43] Update gitignore --- .gitignore | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0b72579..7e4a483 100644 --- a/.gitignore +++ b/.gitignore @@ -146,6 +146,12 @@ dmypy.json # End of https://www.gitignore.io/api/visualstudiocode .vscode/ +# Vim +*~ +*.swp +*.swo + # ProtoTorch artifacts reports -artifacts \ No newline at end of file +artifacts +examples/_*.py \ No newline at end of file From 454718cdf54d725bada9eb001db89c3307032af1 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 12:39:23 +0200 Subject: [PATCH 37/43] Update gitignore --- .gitignore | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 7e4a483..df059f9 100644 --- a/.gitignore +++ b/.gitignore @@ -151,7 +151,8 @@ dmypy.json *.swp *.swo -# ProtoTorch artifacts +# Artifacts created by ProtoTorch reports artifacts -examples/_*.py \ No newline at end of file +examples/_*.py +examples/_*.ipynb \ No newline at end of file From 7763a57058e1e47d27a8ce5dc8e1d32e320ec4c8 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 13:39:09 +0200 Subject: [PATCH 38/43] [FEATURE] Add property `reasoning_matrices` --- prototorch/core/components.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d497cdf..c9edcbb 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -262,20 +262,10 @@ class Reasonings(torch.nn.Module): def num_classes(self): return self._reasonings.shape[1] - # @property - # def reasonings(self): - # """Tensor containing the reasoning matrices.""" - # return self._reasonings.detach().cpu() - @property def reasonings(self): - with torch.no_grad(): - A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) - pk = A - nk = (1 - pk) * B - ik = 1 - pk - nk - img = 
torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-            return img.unsqueeze(1).cpu()
+        """Tensor containing the reasoning matrices."""
+        return self._reasonings.detach().cpu()

     def _register_reasonings(self, reasonings):
         self.register_buffer("_reasonings", reasonings)
@@ -330,6 +320,22 @@ class ReasoningComponents(AbstractComponents):
     def num_classes(self):
         return self._reasonings.shape[1]

+    @property
+    def reasonings(self):
+        """Tensor containing the reasoning matrices."""
+        return self._reasonings.detach().cpu()
+
+    @property
+    def reasoning_matrices(self):
+        """Reasoning matrices for each class."""
+        with torch.no_grad():
+            A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
+            pk = A
+            nk = (1 - pk) * B
+            ik = 1 - pk - nk
+            matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
+            return matrices.cpu()
+
     def _register_reasonings(self, reasonings):
         self.register_parameter("_reasonings", Parameter(reasonings))

From c95f91cc299be08090438577bb794b93744fccda Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Wed, 16 Jun 2021 13:39:28 +0200
Subject: [PATCH 39/43] Update `examples/new_components.py` to use the new API

---
 examples/new_components.py | 55 ++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 32 deletions(-)

diff --git a/examples/new_components.py b/examples/new_components.py
index 2cb5f9f..ff47622 100644
--- a/examples/new_components.py
+++ b/examples/new_components.py
@@ -1,39 +1,35 @@
 """This example script shows the usage of the new components architecture.
 Serialization/deserialization also works as expected.
+
 """

-# DATASET
 import torch
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import StandardScaler
-
-scaler = StandardScaler()
-x_train, y_train = load_iris(return_X_y=True)
-x_train = x_train[:, [0, 2]]
-scaler.fit(x_train)
-x_train = scaler.transform(x_train)
+import prototorch as pt

-x_train = torch.Tensor(x_train)
-y_train = torch.Tensor(y_train)
-num_classes = len(torch.unique(y_train))
+ds = pt.datasets.Iris()

-# CREATE NEW COMPONENTS
-from prototorch.components import *
-from prototorch.components.initializers import *
-
-unsupervised = Components(6, SelectionInitializer(x_train))
+unsupervised = pt.components.Components(
+    6,
+    initializer=pt.initializers.ZCI(2),
+)
 print(unsupervised())

-prototypes = LabeledComponents(
-    (3, 2), StratifiedSelectionInitializer(x_train, y_train))
+prototypes = pt.components.LabeledComponents(
+    (3, 2),
+    components_initializer=pt.initializers.SSCI(ds),
+)
 print(prototypes())

-components = ReasoningComponents(
-    (3, 6), StratifiedSelectionInitializer(x_train, y_train))
-print(components())
+components = pt.components.ReasoningComponents(
+    (3, 2),
+    components_initializer=pt.initializers.SSCI(ds),
+    reasonings_initializer=pt.initializers.PPRI(),
+)
+print(components())

-# TEST SERIALIZATION
+# Test Serialization
 import io

 save = io.BytesIO()
@@ -41,25 +37,20 @@
 torch.save(unsupervised, save)
 save.seek(0)
 serialized_unsupervised = torch.load(save)

-assert torch.all(unsupervised.components == serialized_unsupervised.components
-                 ), "Serialization of Components failed."
+assert torch.all(unsupervised.components == serialized_unsupervised.components)

 save = io.BytesIO()
 torch.save(prototypes, save)
 save.seek(0)
 serialized_prototypes = torch.load(save)

-assert torch.all(prototypes.components == serialized_prototypes.components
-                 ), "Serialization of Components failed."
-assert torch.all(prototypes.component_labels == serialized_prototypes.
- component_labels), "Serialization of Components failed." +assert torch.all(prototypes.components == serialized_prototypes.components) +assert torch.all(prototypes.labels == serialized_prototypes.labels) save = io.BytesIO() torch.save(components, save) save.seek(0) serialized_components = torch.load(save) -assert torch.all(components.components == serialized_components.components - ), "Serialization of Components failed." -assert torch.all(components.reasonings == serialized_components.reasonings - ), "Serialization of Components failed." +assert torch.all(components.components == serialized_components.components) +assert torch.all(components.reasonings == serialized_components.reasonings) From 7a6da0c5fcda6d4bdd7110efb33442a809265e94 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 21:53:36 +0200 Subject: [PATCH 40/43] [FEATURE] Add transforms --- prototorch/__init__.py | 4 ++- prototorch/core/__init__.py | 1 + prototorch/core/initializers.py | 52 ++++++++++++++++++++++++++++++++- tests/test_core.py | 50 +++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index d0ce2d6..3412d9c 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -14,9 +14,10 @@ from .core import ( components, distances, initializers, - similarities, losses, pooling, + similarities, + transforms, ) # Core Setup @@ -33,6 +34,7 @@ __all_core__ = [ "nn", "pooling", "similarities", + "transforms", "utils", ] diff --git a/prototorch/core/__init__.py b/prototorch/core/__init__.py index c205dfa..e5961c1 100644 --- a/prototorch/core/__init__.py +++ b/prototorch/core/__init__.py @@ -7,3 +7,4 @@ from .initializers import * from .losses import * from .pooling import * from .similarities import * +from .transforms import * diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index 6a7067b..f5d2743 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC): @abstractmethod def generate(self, distribution: Union[dict, list, tuple]): ... - return generate_end_hook(...) + return self.generate_end_hook(...) class LiteralReasoningsInitializer(AbstractReasoningsInitializer): @@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): return reasonings +# Transforms +class AbstractTransformInitializer(ABC): + """Abstract class for all transform initializers.""" + ... + + +class AbstractLinearTransformInitializer(AbstractTransformInitializer): + """Abstract class for all linear transform initializers.""" + def __init__(self, out_dim_first: bool = False): + self.out_dim_first = out_dim_first + + def generate_end_hook(self, weights): + if self.out_dim_first: + weights = weights.permute(1, 0) + return weights + + @abstractmethod + def generate(self, in_dim: int, out_dim: int): + ... + return self.generate_end_hook(...) 
+ + +class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with zeros.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with ones.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.ones(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class EyeTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with the largest possible identity matrix.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + I = torch.eye(min(in_dim, out_dim)) + weights[:I.shape[0], :I.shape[1]] = I + return self.generate_end_hook(weights) + + # Aliases - Components CACI = ClassAwareCompInitializer DACI = DataAwareCompInitializer @@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer PPRI = PurePositiveReasoningsInitializer RRI = RandomReasoningsInitializer ZRI = ZerosReasoningsInitializer + +# Aliases - Transforms +Eye = EyeTransformInitializer +OLTI = OnesLinearTransformInitializer +ZLTI = ZerosLinearTransformInitializer diff --git a/tests/test_core.py b/tests/test_core.py index 6fad03f..f949037 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class(): assert reasonings.shape[2] == 3 +# Transform initializers +def test_eye_transform_init_square(): + t = pt.initializers.EyeTransformInitializer() + I = t.generate(3, 3) + assert torch.allclose(I, torch.eye(3)) + + +def test_eye_transform_init_narrow(): + t = pt.initializers.EyeTransformInitializer() + actual = t.generate(3, 2) + desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) + assert torch.allclose(actual, desired) + + +def test_eye_transform_init_wide(): + t = pt.initializers.EyeTransformInitializer() + actual = t.generate(2, 3) + desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) + assert torch.allclose(actual, desired) + + +# Transforms +def test_linear_transform(): + l = pt.transforms.LinearTransform(2, 4) + actual = l.weights + desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]]) + assert torch.allclose(actual, desired) + + +def test_linear_transform_zeros_init(): + l = pt.transforms.LinearTransform( + in_dim=2, + out_dim=4, + initializer=pt.initializers.ZerosLinearTransformInitializer(), + ) + actual = l.weights + desired = torch.zeros(2, 4) + assert torch.allclose(actual, desired) + + +def test_linear_transform_out_dim_first(): + l = pt.transforms.LinearTransform( + in_dim=2, + out_dim=4, + initializer=pt.initializers.OLTI(out_dim_first=True), + ) + assert l.weights.shape[0] == 4 + assert l.weights.shape[1] == 2 + + # Components def test_components_no_initializer(): with pytest.raises(TypeError): From 11cd1b0032a8bf1000af947522575d4bc4800281 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 22:06:33 +0200 Subject: [PATCH 41/43] [BUGFIX] Add missing file --- prototorch/core/transforms.py | 44 +++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 prototorch/core/transforms.py diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py new file mode 100644 index 0000000..3a0ded2 --- /dev/null +++ b/prototorch/core/transforms.py @@ -0,0 +1,44 @@ +"""ProtoTorch transforms""" + +import torch +from torch.nn.parameter import Parameter + +from .initializers import ( + 
AbstractLinearTransformInitializer,
+    EyeTransformInitializer,
+)
+
+
+class LinearTransform(torch.nn.Module):
+    def __init__(
+            self,
+            in_dim: int,
+            out_dim: int,
+            initializer:
+            AbstractLinearTransformInitializer = EyeTransformInitializer(),
+            **kwargs):
+        super().__init__(**kwargs)
+        self.set_weights(in_dim, out_dim, initializer)
+
+    @property
+    def weights(self):
+        return self._weights.detach().cpu()
+
+    def _register_weights(self, weights):
+        self.register_parameter("_weights", Parameter(weights))
+
+    def set_weights(
+            self,
+            in_dim: int,
+            out_dim: int,
+            initializer:
+            AbstractLinearTransformInitializer = EyeTransformInitializer()):
+        weights = initializer.generate(in_dim, out_dim)
+        self._register_weights(weights)
+
+    def forward(self, x):
+        # Matmul with the live parameter (not the detached `weights` copy);
+        # `_weights` has shape (in_dim, out_dim), so no transpose is needed.
+        return x @ self._weights
+
+
+# Aliases
+Omega = LinearTransform
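
A quick usage sketch -- editorial, not part of the patch: with the default
`EyeTransformInitializer`, the transform simply projects inputs onto the
first `out_dim` feature dimensions until it is trained. Assuming the import
paths from this series:

    import torch

    from prototorch.core.transforms import LinearTransform

    omega = LinearTransform(in_dim=3, out_dim=2)  # identity-initialized
    x = torch.randn(5, 3)
    y = omega(x)  # shape: (5, 2)

The `weights` property deliberately returns a detached CPU copy for
inspection and testing, while the forward pass goes through the registered
`_weights` parameter so that gradients can flow.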
""" - def __init__(self, distribution: Union[dict, list, tuple], - components_initializer: AbstractComponentsInitializer, - reasonings_initializer: AbstractReasoningsInitializer, - **kwargs): + def __init__( + self, + distribution: Union[dict, list, tuple], + components_initializer: AbstractComponentsInitializer, + reasonings_initializer: + AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(), + **kwargs): super().__init__(**kwargs) self.add_components(distribution, components_initializer, reasonings_initializer) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index f5d2743..7041cbb 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer): # Reasonings class AbstractReasoningsInitializer(ABC): """Abstract class for all reasonings initializers.""" - def __init__(self, components_first=True): + def __init__(self, components_first: bool = True): self.components_first = components_first def compute_shape(self, distribution): @@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): num_components, num_classes, _ = self.compute_shape(distribution) A = OneHotLabelsInitializer().generate(distribution) B = torch.zeros(num_components, num_classes) - reasonings = torch.stack([A, B]).permute(2, 1, 0) + reasonings = torch.stack([A, B], dim=-1) reasonings = self.generate_end_hook(reasonings) return reasonings diff --git a/tests/test_core.py b/tests/test_core.py index f949037..d007f9b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -220,13 +220,6 @@ def test_ones_reasonings_init(): assert torch.allclose(reasonings, torch.zeros(6, 3, 2)) -def test_random_reasonings_init_channels_not_first(): - r = pt.initializers.RandomReasoningsInitializer(components_first=False) - reasonings = r.generate(distribution=[1, 2]) - assert reasonings.shape[0] == 2 - assert reasonings.shape[-1] == 3 - - def test_pure_positive_reasonings_init_one_per_class(): r = pt.initializers.PurePositiveReasoningsInitializer( components_first=False) @@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class(): assert torch.allclose(reasonings[0], torch.eye(4)) -def test_pure_positive_reasonings_init_unrepresented_class(): - r = pt.initializers.PurePositiveReasoningsInitializer( - components_first=False) - reasonings = r.generate(distribution=[1, 0, 1]) +def test_pure_positive_reasonings_init_unrepresented_classes(): + r = pt.initializers.PurePositiveReasoningsInitializer() + reasonings = r.generate(distribution=[9, 0, 0, 0]) + assert reasonings.shape[0] == 9 + assert reasonings.shape[1] == 4 + assert reasonings.shape[2] == 2 + + +def test_random_reasonings_init_channels_not_first(): + r = pt.initializers.RandomReasoningsInitializer(components_first=False) + reasonings = r.generate(distribution=[0, 0, 0, 1]) assert reasonings.shape[0] == 2 - assert reasonings.shape[1] == 2 - assert reasonings.shape[2] == 3 + assert reasonings.shape[1] == 4 + assert reasonings.shape[2] == 1 # Transform initializers