refactor(api)!: merge the new api changes into dev

BREAKING CHANGE: remove `prototorch/functions/*`, `prototorch/components/*` and `prototorch/modules/*`
BREAKING CHANGE: move `initializers` from the `prototorch.components` namespace into the `prototorch.initializers` namespace
BREAKING CHANGE: `functions` and `modules` are moved into `core` and `nn`
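
Migration sketch (a minimal before/after, lifted from the example changes in this diff; `x_train`/`y_train` stand in for any training tensors, and `SSCI` is the initializer alias the updated examples use):

```python
# Before this commit: flat imports from the now-removed namespaces.
from prototorch.components import LabeledComponents
from prototorch.components.initializers import StratifiedSelectionInitializer

prototypes = LabeledComponents(
    (3, 2), StratifiedSelectionInitializer(x_train, y_train))

# After this commit: everything lives under the new `pt.components` /
# `pt.initializers` namespaces re-exported from `prototorch.core`.
import prototorch as pt

ds = pt.datasets.Iris()
prototypes = pt.components.LabeledComponents(
    (3, 2),
    components_initializer=pt.initializers.SSCI(ds),
)
```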

.bumpversion.cfg
@@ -3,8 +3,8 @@ current_version = 0.5.1
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
-serialize = 
-	{major}.{minor}.{patch}
+serialize = {major}.{minor}.{patch}
+message = bump: {current_version} → {new_version}
 
 [bumpversion:file:setup.py]
 

.gitignore (16 lines changed, vendored)
@@ -129,14 +129,6 @@ dmypy.json
 
 # End of https://www.gitignore.io/api/python
 
-# ProtoFlow
-core
-checkpoint
-logs/
-saved_weights/
-scratch*
-
-
 # Created by https://www.gitignore.io/api/visualstudiocode
 # Edit at https://www.gitignore.io/?templates=visualstudiocode
 
@@ -154,5 +146,13 @@ scratch*
 # End of https://www.gitignore.io/api/visualstudiocode
 .vscode/
 
+# Vim
+*~
+*.swp
+*.swo
+
+# Artifacts created by ProtoTorch
 reports
 artifacts
+examples/_*.py
+examples/_*.ipynb

.pre-commit-config.yaml
@@ -23,19 +23,19 @@ repos:
   -   id: isort
 
 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.902'
+    rev: v0.902
     hooks:
     -   id: mypy
         files: prototorch
         additional_dependencies: [types-pkg_resources]
 
 -   repo: https://github.com/pre-commit/mirrors-yapf
-    rev: 'v0.31.0'  # Use the sha / tag you want to point at
+    rev: v0.31.0
     hooks:
     -   id: yapf
 
 -   repo: https://github.com/pre-commit/pygrep-hooks
-    rev: v1.9.0  # Use the ref you want to point at
+    rev: v1.9.0
     hooks:
     -   id: python-use-type-annotations
     -   id: python-no-log-warn
@@ -47,8 +47,8 @@ repos:
     hooks:
     -   id: pyupgrade
 
--   repo: https://github.com/jorisroovers/gitlint
-    rev: "v0.15.1"
+-   repo: https://github.com/si-cim/gitlint
+    rev: v0.15.2-unofficial
     hooks:
     -   id: gitlint
         args: [--contrib=CT1, --ignore=B6, --msg-filename]

.remarkrc (new file, 7 lines)
@@ -0,0 +1,7 @@
+{
+  "plugins": [
+    "remark-preset-lint-recommended",
+    ["remark-lint-list-item-indent", false],
+    ["no-emphasis-as-header", false]
+  ]
+}

.travis.yml
@@ -1,7 +1,7 @@
 dist: bionic
 sudo: false
 language: python
-python: 3.8
+python: 3.9
 cache:
   directories:
   - "$HOME/.cache/pip"

README.md (10 lines changed)
@@ -51,14 +51,20 @@ that link not work try <https://prototorch.readthedocs.io/en/latest/>.
 ## Contribution
 
 This repository contains definition for [git hooks](https://githooks.com).
-[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
-Please install the hooks by running
+[Pre-commit](https://pre-commit.com) is automatically installed as development
+dependency with prototorch or you can install it manually with `pip install
+pre-commit`.
+
+Please install the hooks by running:
 ```bash
 pre-commit install
 pre-commit install --hook-type commit-msg
 ```
 before creating the first commit.
 
+The commit will fail if the commit message does not follow the specification
+provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
+
 ## Bibtex
 
 If you would like to cite the package, please use this:

examples/cbc_iris.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+"""ProtoTorch CBC example using 2D Iris data."""
+
+import torch
+from matplotlib import pyplot as plt
+
+import prototorch as pt
+
+
+class CBC(torch.nn.Module):
+    def __init__(self, data, **kwargs):
+        super().__init__(**kwargs)
+        self.components_layer = pt.components.ReasoningComponents(
+            distribution=[2, 1, 2],
+            components_initializer=pt.initializers.SSCI(data, noise=0.1),
+            reasonings_initializer=pt.initializers.PPRI(components_first=True),
+        )
+
+    def forward(self, x):
+        components, reasonings = self.components_layer()
+        sims = pt.similarities.euclidean_similarity(x, components)
+        probs = pt.competitions.cbcc(sims, reasonings)
+        return probs
+
+
+class VisCBC2D():
+    def __init__(self, model, data):
+        self.model = model
+        self.x_train, self.y_train = pt.utils.parse_data_arg(data)
+        self.title = "Components Visualization"
+        self.fig = plt.figure(self.title)
+        self.border = 0.1
+        self.resolution = 100
+        self.cmap = "viridis"
+
+    def on_epoch_end(self):
+        x_train, y_train = self.x_train, self.y_train
+        _components = self.model.components_layer._components.detach()
+        ax = self.fig.gca()
+        ax.cla()
+        ax.set_title(self.title)
+        ax.axis("off")
+        ax.scatter(
+            x_train[:, 0],
+            x_train[:, 1],
+            c=y_train,
+            cmap=self.cmap,
+            edgecolor="k",
+            marker="o",
+            s=30,
+        )
+        ax.scatter(
+            _components[:, 0],
+            _components[:, 1],
+            c="w",
+            cmap=self.cmap,
+            edgecolor="k",
+            marker="D",
+            s=50,
+        )
+        x = torch.vstack((x_train, _components))
+        mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
+        with torch.no_grad():
+            y_pred = self.model(
+                torch.Tensor(mesh_input).type_as(_components)).argmax(1)
+        y_pred = y_pred.cpu().reshape(xx.shape)
+        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
+        plt.pause(0.2)
+
+
+if __name__ == "__main__":
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
+
+    model = CBC(train_ds)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    criterion = pt.losses.MarginLoss(margin=0.1)
+    vis = VisCBC2D(model, train_ds)
+
+    for epoch in range(200):
+        correct = 0.0
+        for x, y in train_loader:
+            y_oh = torch.eye(3)[y]
+            y_pred = model(x)
+            loss = criterion(y_pred, y_oh).mean(0)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            correct += (y_pred.argmax(1) == y).float().sum(0)
+
+        acc = 100 * correct / len(train_ds)
+        print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
+        vis.on_epoch_end()

examples/new_components.py
@@ -1,39 +1,35 @@
 """This example script shows the usage of the new components architecture.
 
 Serialization/deserialization also works as expected.
 
 """
-# DATASET
+
 import torch
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import StandardScaler
 
-scaler = StandardScaler()
-x_train, y_train = load_iris(return_X_y=True)
-x_train = x_train[:, [0, 2]]
-scaler.fit(x_train)
-x_train = scaler.transform(x_train)
+import prototorch as pt
 
-x_train = torch.Tensor(x_train)
-y_train = torch.Tensor(y_train)
-num_classes = len(torch.unique(y_train))
+ds = pt.datasets.Iris()
 
-# CREATE NEW COMPONENTS
-from prototorch.components import *
-from prototorch.components.initializers import *
-
-unsupervised = Components(6, SelectionInitializer(x_train))
+unsupervised = pt.components.Components(
+    6,
+    initializer=pt.initializers.ZCI(2),
+)
 print(unsupervised())
 
-prototypes = LabeledComponents(
-    (3, 2), StratifiedSelectionInitializer(x_train, y_train))
+prototypes = pt.components.LabeledComponents(
+    (3, 2),
+    components_initializer=pt.initializers.SSCI(ds),
+)
 print(prototypes())
 
-components = ReasoningComponents(
-    (3, 6), StratifiedSelectionInitializer(x_train, y_train))
-print(components())
+components = pt.components.ReasoningComponents(
+    (3, 2),
+    components_initializer=pt.initializers.SSCI(ds),
+    reasonings_initializer=pt.initializers.PPRI(),
+)
+print(prototypes())
 
-# TEST SERIALIZATION
+# Test Serialization
 import io
 
 save = io.BytesIO()
@@ -41,25 +37,20 @@ torch.save(unsupervised, save)
 save.seek(0)
 serialized_unsupervised = torch.load(save)
 
-assert torch.all(unsupervised.components == serialized_unsupervised.components
-                 ), "Serialization of Components failed."
+assert torch.all(unsupervised.components == serialized_unsupervised.components)
 
 save = io.BytesIO()
 torch.save(prototypes, save)
 save.seek(0)
 serialized_prototypes = torch.load(save)
 
-assert torch.all(prototypes.components == serialized_prototypes.components
-                 ), "Serialization of Components failed."
-assert torch.all(prototypes.component_labels == serialized_prototypes.
-                 component_labels), "Serialization of Components failed."
+assert torch.all(prototypes.components == serialized_prototypes.components)
+assert torch.all(prototypes.labels == serialized_prototypes.labels)
 
 save = io.BytesIO()
 torch.save(components, save)
 save.seek(0)
 serialized_components = torch.load(save)
 
-assert torch.all(components.components == serialized_components.components
-                 ), "Serialization of Components failed."
-assert torch.all(components.reasonings == serialized_components.reasonings
-                 ), "Serialization of Components failed."
+assert torch.all(components.components == serialized_components.components)
+assert torch.all(components.reasonings == serialized_components.reasonings)

prototorch/__init__.py
@@ -1,21 +1,41 @@
-"""ProtoTorch package."""
+"""ProtoTorch package"""
 
 import pkgutil
 from typing import List
 
 import pkg_resources
 
-from . import components, datasets, functions, modules, utils
-from .datasets import *
+from . import (
+    datasets,
+    nn,
+    utils,
+)
+from .core import (
+    competitions,
+    components,
+    distances,
+    initializers,
+    losses,
+    pooling,
+    similarities,
+    transforms,
+)
 
 # Core Setup
 __version__ = "0.5.1"
 
 __all_core__ = [
-    "datasets",
-    "functions",
-    "modules",
+    "competitions",
     "components",
+    "core",
+    "datasets",
+    "distances",
+    "initializers",
+    "losses",
+    "nn",
+    "pooling",
+    "similarities",
+    "transforms",
     "utils",
 ]
 

prototorch/components/__init__.py (deleted)
@@ -1,2 +0,0 @@
-from prototorch.components.components import *
-from prototorch.components.initializers import *

prototorch/components/components.py (deleted, 270 lines)
@@ -1,270 +0,0 @@
-"""ProtoTorch components modules."""
-
-import warnings
-
-import torch
-from torch.nn.parameter import Parameter
-
-from prototorch.components.initializers import (ClassAwareInitializer,
-                                                ComponentsInitializer,
-                                                EqualLabelsInitializer,
-                                                UnequalLabelsInitializer,
-                                                ZeroReasoningsInitializer)
-
-from .initializers import parse_data_arg
-
-
-def get_labels_object(distribution):
-    if isinstance(distribution, dict):
-        if "num_classes" in distribution.keys():
-            labels = EqualLabelsInitializer(
-                distribution["num_classes"],
-                distribution["prototypes_per_class"])
-        else:
-            clabels = list(distribution.keys())
-            dist = list(distribution.values())
-            labels = UnequalLabelsInitializer(dist, clabels)
-    elif isinstance(distribution, tuple):
-        num_classes, prototypes_per_class = distribution
-        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-    elif isinstance(distribution, list):
-        labels = UnequalLabelsInitializer(distribution)
-    else:
-        msg = f"`distribution` not understood." \
-            f"You have provided: {distribution=}."
-        raise ValueError(msg)
-    return labels
-
-
-def _precheck_initializer(initializer):
-    if not isinstance(initializer, ComponentsInitializer):
-        emsg = f"`initializer` has to be some subtype of " \
-            f"{ComponentsInitializer}. " \
-            f"You have provided: {initializer=} instead."
-        raise TypeError(emsg)
-
-
-class LinearMapping(torch.nn.Module):
-    """LinearMapping is a learnable Mapping Matrix."""
-    def __init__(self,
-                 mapping_shape=None,
-                 initializer=None,
-                 *,
-                 initialized_linearmapping=None):
-        super().__init__()
-
-        # Ignore all initialization settings if initialized_components is given.
-        if initialized_linearmapping is not None:
-            self._register_mapping(initialized_linearmapping)
-            if num_components is not None or initializer is not None:
-                wmsg = "Arguments ignored while initializing Components"
-                warnings.warn(wmsg)
-        else:
-            self._initialize_mapping(mapping_shape, initializer)
-
-    @property
-    def mapping_shape(self):
-        return self._omega.shape
-
-    def _register_mapping(self, components):
-        self.register_parameter("_omega", Parameter(components))
-
-    def _initialize_mapping(self, mapping_shape, initializer):
-        _precheck_initializer(initializer)
-        _mapping = initializer.generate(mapping_shape)
-        self._register_mapping(_mapping)
-
-    @property
-    def mapping(self):
-        """Tensor containing the component tensors."""
-        return self._omega.detach()
-
-    def forward(self):
-        return self._omega
-
-
-class Components(torch.nn.Module):
-    """Components is a set of learnable Tensors."""
-    def __init__(self,
-                 num_components=None,
-                 initializer=None,
-                 *,
-                 initialized_components=None):
-        super().__init__()
-
-        # Ignore all initialization settings if initialized_components is given.
-        if initialized_components is not None:
-            self._register_components(initialized_components)
-            if num_components is not None or initializer is not None:
-                wmsg = "Arguments ignored while initializing Components"
-                warnings.warn(wmsg)
-        else:
-            self._initialize_components(num_components, initializer)
-
-    @property
-    def num_components(self):
-        return len(self._components)
-
-    def _register_components(self, components):
-        self.register_parameter("_components", Parameter(components))
-
-    def _initialize_components(self, num_components, initializer):
-        _precheck_initializer(initializer)
-        _components = initializer.generate(num_components)
-        self._register_components(_components)
-
-    def add_components(self,
-                       num=1,
-                       initializer=None,
-                       *,
-                       initialized_components=None):
-        if initialized_components is not None:
-            _components = torch.cat([self._components, initialized_components])
-        else:
-            _precheck_initializer(initializer)
-            _new = initializer.generate(num)
-            _components = torch.cat([self._components, _new])
-        self._register_components(_components)
-
-    def remove_components(self, indices=None):
-        mask = torch.ones(self.num_components, dtype=torch.bool)
-        mask[indices] = False
-        _components = self._components[mask]
-        self._register_components(_components)
-        return mask
-
-    @property
-    def components(self):
-        """Tensor containing the component tensors."""
-        return self._components.detach()
-
-    def forward(self):
-        return self._components
-
-    def extra_repr(self):
-        return f"(components): (shape: {tuple(self._components.shape)})"
-
-
-class LabeledComponents(Components):
-    """LabeledComponents generate a set of components and a set of labels.
-
-    Every Component has a label assigned.
-    """
-    def __init__(self,
-                 distribution=None,
-                 initializer=None,
-                 *,
-                 initialized_components=None):
-        if initialized_components is not None:
-            components, component_labels = parse_data_arg(
-                initialized_components)
-            super().__init__(initialized_components=components)
-            self._register_labels(component_labels)
-        else:
-            labels = get_labels_object(distribution)
-            self.initial_distribution = labels.distribution
-            _labels = labels.generate()
-            super().__init__(len(_labels), initializer=initializer)
-            self._register_labels(_labels)
-
-    def _register_labels(self, labels):
-        self.register_buffer("_labels", labels)
-
-    @property
-    def distribution(self):
-        clabels, counts = torch.unique(self._labels,
-                                       sorted=True,
-                                       return_counts=True)
-        return dict(zip(clabels.tolist(), counts.tolist()))
-
-    def _initialize_components(self, num_components, initializer):
-        if isinstance(initializer, ClassAwareInitializer):
-            _precheck_initializer(initializer)
-            _components = initializer.generate(num_components,
-                                               self.initial_distribution)
-            self._register_components(_components)
-        else:
-            super()._initialize_components(num_components, initializer)
-
-    def add_components(self, distribution, initializer):
-        _precheck_initializer(initializer)
-
-        # Labels
-        labels = get_labels_object(distribution)
-        new_labels = labels.generate()
-        _labels = torch.cat([self._labels, new_labels])
-        self._register_labels(_labels)
-
-        # Components
-        if isinstance(initializer, ClassAwareInitializer):
-            _new = initializer.generate(len(new_labels), distribution)
-        else:
-            _new = initializer.generate(len(new_labels))
-        _components = torch.cat([self._components, _new])
-        self._register_components(_components)
-
-    def remove_components(self, indices=None):
-        # Components
-        mask = super().remove_components(indices)
-
-        # Labels
-        _labels = self._labels[mask]
-        self._register_labels(_labels)
-
-    @property
-    def component_labels(self):
-        """Tensor containing the component tensors."""
-        return self._labels.detach()
-
-    def forward(self):
-        return super().forward(), self._labels
-
-
-class ReasoningComponents(Components):
-    r"""ReasoningComponents generate a set of components and a set of reasoning matrices.
-
-    Every Component has a reasoning matrix assigned.
-
-    A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
-    first element is called positive reasoning :math:`p`, the second negative
-    reasoning :math:`n`. A components can reason in favour (positive) of a
-    class, against (negative) a class or not at all (neutral).
-
-    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
-    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
-    three element probability distribution.
-
-    """
-    def __init__(self,
-                 reasonings=None,
-                 initializer=None,
-                 *,
-                 initialized_components=None):
-        if initialized_components is not None:
-            components, reasonings = initialized_components
-
-            super().__init__(initialized_components=components)
-            self.register_parameter("_reasonings", reasonings)
-        else:
-            self._initialize_reasonings(reasonings)
-            super().__init__(len(self._reasonings), initializer=initializer)
-
-    def _initialize_reasonings(self, reasonings):
-        if isinstance(reasonings, tuple):
-            num_classes, num_components = reasonings
-            reasonings = ZeroReasoningsInitializer(num_classes, num_components)
-
-        _reasonings = reasonings.generate()
-        self.register_parameter("_reasonings", _reasonings)
-
-    @property
-    def reasonings(self):
-        """Returns Reasoning Matrix.
-
-        Dimension NxCx2
-
-        """
-        return self._reasonings.detach()
-
-    def forward(self):
-        return super().forward(), self._reasonings

prototorch/components/initializers.py (deleted, 233 lines)
@@ -1,233 +0,0 @@
-"""ProtoTroch Initializers."""
-import warnings
-from collections.abc import Iterable
-from itertools import chain
-
-import torch
-from torch.utils.data import DataLoader, Dataset
-
-
-def parse_data_arg(data_arg):
-    if isinstance(data_arg, Dataset):
-        data_arg = DataLoader(data_arg, batch_size=len(data_arg))
-
-    if isinstance(data_arg, DataLoader):
-        data = torch.tensor([])
-        targets = torch.tensor([])
-        for x, y in data_arg:
-            data = torch.cat([data, x])
-            targets = torch.cat([targets, y])
-    else:
-        data, targets = data_arg
-        if not isinstance(data, torch.Tensor):
-            wmsg = f"Converting data to {torch.Tensor}."
-            warnings.warn(wmsg)
-            data = torch.Tensor(data)
-        if not isinstance(targets, torch.Tensor):
-            wmsg = f"Converting targets to {torch.Tensor}."
-            warnings.warn(wmsg)
-            targets = torch.Tensor(targets)
-    return data, targets
-
-
-def get_subinitializers(data, targets, clabels, subinit_type):
-    initializers = dict()
-    for clabel in clabels:
-        class_data = data[targets == clabel]
-        class_initializer = subinit_type(class_data)
-        initializers[clabel] = (class_initializer)
-    return initializers
-
-
-# Components
-class ComponentsInitializer(object):
-    def generate(self, number_of_components):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class DimensionAwareInitializer(ComponentsInitializer):
-    def __init__(self, dims):
-        super().__init__()
-        if isinstance(dims, Iterable):
-            self.components_dims = tuple(dims)
-        else:
-            self.components_dims = (dims, )
-
-
-class OnesInitializer(DimensionAwareInitializer):
-    def __init__(self, dims, scale=1.0):
-        super().__init__(dims)
-        self.scale = scale
-
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims) * self.scale
-
-
-class ZerosInitializer(DimensionAwareInitializer):
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.zeros(gen_dims)
-
-
-class UniformInitializer(DimensionAwareInitializer):
-    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
-        super().__init__(dims)
-        self.minimum = minimum
-        self.maximum = maximum
-        self.scale = scale
-
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims).uniform_(self.minimum,
-                                             self.maximum) * self.scale
-
-
-class DataAwareInitializer(ComponentsInitializer):
-    def __init__(self, data, transform=torch.nn.Identity()):
-        super().__init__()
-        self.data = data
-        self.transform = transform
-
-    def __del__(self):
-        del self.data
-
-
-class SelectionInitializer(DataAwareInitializer):
-    def generate(self, length):
-        indices = torch.LongTensor(length).random_(0, len(self.data))
-        return self.transform(self.data[indices])
-
-
-class MeanInitializer(DataAwareInitializer):
-    def generate(self, length):
-        mean = torch.mean(self.data, dim=0)
-        repeat_dim = [length] + [1] * len(mean.shape)
-        return self.transform(mean.repeat(repeat_dim))
-
-
-class ClassAwareInitializer(DataAwareInitializer):
-    def __init__(self, data, transform=torch.nn.Identity()):
-        data, targets = parse_data_arg(data)
-        super().__init__(data, transform)
-        self.targets = targets
-        self.clabels = torch.unique(self.targets).int().tolist()
-        self.num_classes = len(self.clabels)
-
-    def _get_samples_from_initializer(self, length, dist):
-        if not dist:
-            per_class = length // self.num_classes
-            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
-        if isinstance(dist, list):
-            dist = dict(zip(self.clabels, dist))
-        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
-        out = torch.vstack(samples)
-        with torch.no_grad():
-            out = self.transform(out)
-        return out
-
-    def __del__(self):
-        del self.data
-        del self.targets
-
-
-class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, data, **kwargs):
-        super().__init__(data, **kwargs)
-        self.initializers = get_subinitializers(self.data, self.targets,
-                                                self.clabels, MeanInitializer)
-
-    def generate(self, length, dist):
-        samples = self._get_samples_from_initializer(length, dist)
-        return samples
-
-
-class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, data, noise=None, **kwargs):
-        super().__init__(data, **kwargs)
-        self.noise = noise
-        self.initializers = get_subinitializers(self.data, self.targets,
-                                                self.clabels,
-                                                SelectionInitializer)
-
-    def add_noise_v1(self, x):
-        return x + self.noise
-
-    def add_noise_v2(self, x):
-        """Shifts some dimensions of the data randomly."""
-        n1 = torch.rand_like(x)
-        n2 = torch.rand_like(x)
-        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
-        return x + (self.noise * mask)
-
-    def generate(self, length, dist):
-        samples = self._get_samples_from_initializer(length, dist)
-        if self.noise is not None:
-            samples = self.add_noise_v1(samples)
-        return samples
-
-
-# Omega matrix
-class PcaInitializer(DataAwareInitializer):
-    def generate(self, shape):
-        (input_dim, latent_dim) = shape
-        (_, eigVal, eigVec) = torch.pca_lowrank(self.data, q=latent_dim)
-        return eigVec
-
-
-# Labels
-class LabelsInitializer:
-    def generate(self):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class UnequalLabelsInitializer(LabelsInitializer):
-    def __init__(self, dist, clabels=None):
-        self.dist = dist
-        self.clabels = clabels or range(len(self.dist))
-
-    @property
-    def distribution(self):
-        return self.dist
-
-    def generate(self):
-        targets = list(
-            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
-        return torch.LongTensor(targets)
-
-
-class EqualLabelsInitializer(LabelsInitializer):
-    def __init__(self, classes, per_class):
-        self.classes = classes
-        self.per_class = per_class
-
-    @property
-    def distribution(self):
-        return self.classes * [self.per_class]
-
-    def generate(self):
-        return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
-
-
-# Reasonings
-class ReasoningsInitializer:
-    def generate(self, length):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class ZeroReasoningsInitializer(ReasoningsInitializer):
-    def __init__(self, classes, length):
-        self.classes = classes
-        self.length = length
-
-    def generate(self):
-        return torch.zeros((self.length, self.classes, 2))
-
-
-# Aliases
-SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
-SMI = StratifiedMeanInitializer
-Random = RandomInitializer = UniformInitializer
-Zeros = ZerosInitializer
-Ones = OnesInitializer
-PCA = PcaInitializer

prototorch/core/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""ProtoTorch core"""
+
+from .competitions import *
+from .components import *
+from .distances import *
+from .initializers import *
+from .losses import *
+from .pooling import *
+from .similarities import *
+from .transforms import *

prototorch/core/competitions.py (new file, 89 lines)
							@@ -0,0 +1,89 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch competitions"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def wtac(distances: torch.Tensor, labels: torch.LongTensor):
 | 
				
			||||||
 | 
					    """Winner-Takes-All-Competition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns the labels corresponding to the winners.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    winning_indices = torch.min(distances, dim=1).indices
 | 
				
			||||||
 | 
					    winning_labels = labels[winning_indices].squeeze()
 | 
				
			||||||
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1):
 | 
				
			||||||
 | 
					    """K-Nearest-Neighbors-Competition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns the labels corresponding to the winners.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    winning_indices = torch.topk(-distances, k=k, dim=1).indices
 | 
				
			||||||
 | 
					    winning_labels = torch.mode(labels[winning_indices], dim=1).values
 | 
				
			||||||
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
 | 
				
			||||||
 | 
					    """Classification-By-Components Competition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns probability distributions over the classes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    `detections` must be of shape [batch_size, num_components].
 | 
				
			||||||
 | 
					    `reasonings` must be of shape [num_components, num_classes, 2].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
 | 
				
			||||||
 | 
					    pk = A
 | 
				
			||||||
 | 
					    nk = (1 - A) * B
 | 
				
			||||||
 | 
					    numerator = (detections @ (pk - nk).T) + nk.sum(1)
 | 
				
			||||||
 | 
					    probs = numerator / (pk + nk).sum(1)
 | 
				
			||||||
 | 
					    return probs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WTAC(torch.nn.Module):
 | 
				
			||||||
 | 
					    """Winner-Takes-All-Competition Layer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Thin wrapper over the `wtac` function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def forward(self, distances, labels):
 | 
				
			||||||
 | 
					        return wtac(distances, labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function, applied to the negated input so
    that the largest value wins.

    """
    def forward(self, probs, labels):
        return wtac(-1.0 * probs, labels)


class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.

    """
    def __init__(self, k=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        return f"k: {self.k}"


class CBCC(torch.nn.Module):
    """Classification-By-Components Competition.

    Thin wrapper over the `cbcc` function.

    """
    def forward(self, detections, reasonings):
        return cbcc(detections, reasonings)
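
A quick shape check for the competition functions above (illustrative values only; the shapes follow the docstrings):

import torch
from prototorch.core.competitions import cbcc, knnc

distances = torch.rand(8, 5)         # [batch_size, num_prototypes]
labels = torch.randint(0, 3, (5, ))  # one label per prototype
knnc(distances, labels, k=3).shape   # torch.Size([8])

detections = torch.rand(8, 4)        # [batch_size, num_components]
reasonings = torch.rand(4, 3, 2)     # [num_components, num_classes, 2]
cbcc(detections, reasonings).shape   # torch.Size([8, 3])
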
prototorch/core/components.py (new file, 370 lines)
@@ -0,0 +1,370 @@
					"""ProtoTorch components"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..utils import parse_distribution
 | 
				
			||||||
 | 
					from .initializers import (
 | 
				
			||||||
 | 
					    AbstractClassAwareCompInitializer,
 | 
				
			||||||
 | 
					    AbstractComponentsInitializer,
 | 
				
			||||||
 | 
					    AbstractLabelsInitializer,
 | 
				
			||||||
 | 
					    AbstractReasoningsInitializer,
 | 
				
			||||||
 | 
					    LabelsInitializer,
 | 
				
			||||||
 | 
					    PurePositiveReasoningsInitializer,
 | 
				
			||||||
 | 
					    RandomReasoningsInitializer,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_initializer(initializer, instanceof):
 | 
				
			||||||
 | 
					    """Check if the initializer is valid."""
 | 
				
			||||||
 | 
					    if not isinstance(initializer, instanceof):
 | 
				
			||||||
 | 
					        emsg = f"`initializer` has to be an instance " \
 | 
				
			||||||
 | 
					            f"of some subtype of {instanceof}. " \
 | 
				
			||||||
 | 
					            f"You have provided: {initializer} instead. "
 | 
				
			||||||
 | 
					        helpmsg = ""
 | 
				
			||||||
 | 
					        if inspect.isclass(initializer):
 | 
				
			||||||
 | 
					            helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
 | 
				
			||||||
 | 
					                f"with the brackets instead of just {initializer.__name__}?"
 | 
				
			||||||
 | 
					        raise TypeError(emsg + helpmsg)
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gencat(ins, attr, init, *iargs, **ikwargs):
 | 
				
			||||||
 | 
					    """Generate new items and concatenate with existing items."""
 | 
				
			||||||
 | 
					    new_items = init.generate(*iargs, **ikwargs)
 | 
				
			||||||
 | 
					    if hasattr(ins, attr):
 | 
				
			||||||
 | 
					        items = torch.cat([getattr(ins, attr), new_items])
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        items = new_items
 | 
				
			||||||
 | 
					    return items, new_items
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def removeind(ins, attr, indices):
 | 
				
			||||||
 | 
					    """Remove items at specified indices."""
 | 
				
			||||||
 | 
					    mask = torch.ones(len(ins), dtype=torch.bool)
 | 
				
			||||||
 | 
					    mask[indices] = False
 | 
				
			||||||
 | 
					    items = getattr(ins, attr)[mask]
 | 
				
			||||||
 | 
					    return items, mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_cikwargs(init, distribution):
 | 
				
			||||||
 | 
					    """Return appropriate key-word arguments for a component initializer."""
 | 
				
			||||||
 | 
					    if isinstance(init, AbstractClassAwareCompInitializer):
 | 
				
			||||||
 | 
					        cikwargs = dict(distribution=distribution)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					        num_components = sum(distribution.values())
 | 
				
			||||||
 | 
					        cikwargs = dict(num_components=num_components)
 | 
				
			||||||
 | 
					    return cikwargs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AbstractComponents(torch.nn.Module):
 | 
				
			||||||
 | 
					    """Abstract class for all components modules."""
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def num_components(self):
 | 
				
			||||||
 | 
					        """Current number of components."""
 | 
				
			||||||
 | 
					        return len(self._components)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def components(self):
 | 
				
			||||||
 | 
					        """Detached Tensor containing the components."""
 | 
				
			||||||
 | 
					        return self._components.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _register_components(self, components):
 | 
				
			||||||
 | 
					        self.register_parameter("_components", Parameter(components))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def extra_repr(self):
 | 
				
			||||||
 | 
					        return f"components: (shape: {tuple(self._components.shape)})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return self.num_components
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Components(AbstractComponents):
 | 
				
			||||||
 | 
					    """A set of adaptable Tensors."""
 | 
				
			||||||
 | 
					    def __init__(self, num_components: int,
 | 
				
			||||||
 | 
					                 initializer: AbstractComponentsInitializer):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.add_components(num_components, initializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_components(self, num_components: int,
 | 
				
			||||||
 | 
					                       initializer: AbstractComponentsInitializer):
 | 
				
			||||||
 | 
					        """Generate and add new components."""
 | 
				
			||||||
 | 
					        assert validate_initializer(initializer, AbstractComponentsInitializer)
 | 
				
			||||||
 | 
					        _components, new_components = gencat(self, "_components", initializer,
 | 
				
			||||||
 | 
					                                             num_components)
 | 
				
			||||||
 | 
					        self._register_components(_components)
 | 
				
			||||||
 | 
					        return new_components
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_components(self, indices):
 | 
				
			||||||
 | 
					        """Remove components at specified indices."""
 | 
				
			||||||
 | 
					        _components, mask = removeind(self, "_components", indices)
 | 
				
			||||||
 | 
					        self._register_components(_components)
 | 
				
			||||||
 | 
					        return mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self):
 | 
				
			||||||
 | 
					        """Simply return the components parameter Tensor."""
 | 
				
			||||||
 | 
					        return self._components
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
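# Illustration (not part of the file): a `Components` module pairs a count
# with an initializer from `prototorch.core.initializers`, e.g.
#
#   from prototorch.core.initializers import RandomNormalCompInitializer
#   comps = Components(3, RandomNormalCompInitializer(shape=2))
#   comps().shape  # torch.Size([3, 2])
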
class AbstractLabels(torch.nn.Module):
    """Abstract class for all labels modules."""
    @property
    def labels(self):
        return self._labels.cpu()

    @property
    def num_labels(self):
        return len(self._labels)

    @property
    def unique_labels(self):
        return torch.unique(self._labels)

    @property
    def num_unique(self):
        return len(self.unique_labels)

    @property
    def distribution(self):
        unique, counts = torch.unique(self._labels,
                                      sorted=True,
                                      return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def extra_repr(self):
        r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
        if len(self.distribution) < 11:  # avoid lengthy representations
            d = self.distribution
            unique, counts = list(d.keys()), list(d.values())
            r += f", unique: {unique}, counts: {counts}"
        return r

    def __len__(self):
        return self.num_labels


class Labels(AbstractLabels):
    """A set of standalone labels."""
    def __init__(self,
                 distribution: Union[dict, list, tuple],
                 initializer: AbstractLabelsInitializer = LabelsInitializer()):
        super().__init__()
        self.add_labels(distribution, initializer)

    def add_labels(
        self,
        distribution: Union[dict, tuple, list],
        initializer: AbstractLabelsInitializer = LabelsInitializer()):
        """Generate and add new labels."""
        assert validate_initializer(initializer, AbstractLabelsInitializer)
        _labels, new_labels = gencat(self, "_labels", initializer,
                                     distribution)
        self._register_labels(_labels)
        return new_labels

    def remove_labels(self, indices):
        """Remove labels at specified indices."""
        _labels, mask = removeind(self, "_labels", indices)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the labels."""
        return self._labels


class LabeledComponents(AbstractComponents):
    """A set of adaptable components and corresponding unadaptable labels."""
    def __init__(
        self,
        distribution: Union[dict, list, tuple],
        components_initializer: AbstractComponentsInitializer,
        labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        super().__init__()
        self.add_components(distribution, components_initializer,
                            labels_initializer)

    @property
    def distribution(self):
        unique, counts = torch.unique(self._labels,
                                      sorted=True,
                                      return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    @property
    def num_classes(self):
        return len(self.distribution.keys())

    @property
    def labels(self):
        """Tensor containing the component labels."""
        return self._labels.cpu()

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def add_components(
        self,
        distribution,
        components_initializer,
        labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        """Generate and add new components and labels."""
        assert validate_initializer(components_initializer,
                                    AbstractComponentsInitializer)
        assert validate_initializer(labels_initializer,
                                    AbstractLabelsInitializer)
        cikwargs = get_cikwargs(components_initializer, distribution)
        _components, new_components = gencat(self, "_components",
                                             components_initializer,
                                             **cikwargs)
        _labels, new_labels = gencat(self, "_labels", labels_initializer,
                                     distribution)
        self._register_components(_components)
        self._register_labels(_labels)
        return new_components, new_labels

    def remove_components(self, indices):
        """Remove components and labels at specified indices."""
        _components, mask = removeind(self, "_components", indices)
        _labels, mask = removeind(self, "_labels", indices)
        self._register_components(_components)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor and labels."""
        return self._components, self._labels


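# Illustration (not part of the file): `LabeledComponents` ties prototypes to
# class labels; a `distribution` such as {0: 2, 1: 3} requests two components
# for class 0 and three for class 1, e.g.
#
#   from prototorch.core.initializers import ZerosCompInitializer
#   lc = LabeledComponents({0: 2, 1: 3}, ZerosCompInitializer(shape=4))
#   protos, plabels = lc()  # shapes: [5, 4] and [5]
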
class Reasonings(torch.nn.Module):
    """A set of standalone reasoning matrices.

    The `reasonings` tensor is of shape [num_components, num_classes, 2].

    """
    def __init__(
        self,
        distribution: Union[dict, list, tuple],
        initializer:
        AbstractReasoningsInitializer = RandomReasoningsInitializer()):
        super().__init__()
        self.add_reasonings(distribution, initializer)

    @property
    def num_classes(self):
        return self._reasonings.shape[1]

    @property
    def reasonings(self):
        """Tensor containing the reasoning matrices."""
        return self._reasonings.detach().cpu()

    def _register_reasonings(self, reasonings):
        self.register_buffer("_reasonings", reasonings)

    def add_reasonings(
        self,
        distribution: Union[dict, list, tuple],
        initializer:
        AbstractReasoningsInitializer = RandomReasoningsInitializer()):
        """Generate and add new reasonings."""
        assert validate_initializer(initializer, AbstractReasoningsInitializer)
        _reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
                                             distribution)
        self._register_reasonings(_reasonings)
        return new_reasonings

    def remove_reasonings(self, indices):
        """Remove reasonings at specified indices."""
        _reasonings, mask = removeind(self, "_reasonings", indices)
        self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the reasonings."""
        return self._reasonings


class ReasoningComponents(AbstractComponents):
					    r"""A set of components and a corresponding adapatable reasoning matrices.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Every component has its own reasoning matrix.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
 | 
				
			||||||
 | 
					    first element is called positive reasoning :math:`p`, the second negative
 | 
				
			||||||
 | 
					    reasoning :math:`n`. A components can reason in favour (positive) of a
 | 
				
			||||||
 | 
					    class, against (negative) a class or not at all (neutral).
 | 
				
			||||||
 | 
					
    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three element probability distribution.

    """
    def __init__(
        self,
        distribution: Union[dict, list, tuple],
        components_initializer: AbstractComponentsInitializer,
        reasonings_initializer:
        AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()):
        super().__init__()
        self.add_components(distribution, components_initializer,
                            reasonings_initializer)

    @property
    def num_classes(self):
        return self._reasonings.shape[1]

    @property
    def reasonings(self):
        """Tensor containing the reasoning matrices."""
        return self._reasonings.detach().cpu()

    @property
    def reasoning_matrices(self):
        """Reasoning matrices for each class."""
        with torch.no_grad():
            A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
            pk = A
            nk = (1 - pk) * B
            ik = 1 - pk - nk
            matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
        return matrices.cpu()

    def _register_reasonings(self, reasonings):
        self.register_parameter("_reasonings", Parameter(reasonings))

    def add_components(self, distribution, components_initializer,
                       reasonings_initializer: AbstractReasoningsInitializer):
        """Generate and add new components and reasonings."""
        assert validate_initializer(components_initializer,
                                    AbstractComponentsInitializer)
        assert validate_initializer(reasonings_initializer,
                                    AbstractReasoningsInitializer)
        cikwargs = get_cikwargs(components_initializer, distribution)
        _components, new_components = gencat(self, "_components",
                                             components_initializer,
                                             **cikwargs)
        _reasonings, new_reasonings = gencat(self, "_reasonings",
                                             reasonings_initializer,
                                             distribution)
        self._register_components(_components)
        self._register_reasonings(_reasonings)
        return new_components, new_reasonings

    def remove_components(self, indices):
        """Remove components and reasonings at specified indices."""
        _components, mask = removeind(self, "_components", indices)
        _reasonings, mask = removeind(self, "_reasonings", indices)
        self._register_components(_components)
        self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the components and reasonings."""
        return self._components, self._reasonings
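
To make the component/reasoning pairing above concrete, a minimal sketch (illustrative, not part of the commit; `OnesCompInitializer` is defined in `prototorch/core/initializers.py` further down in this commit):

from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import OnesCompInitializer

rc = ReasoningComponents({0: 1, 1: 2}, OnesCompInitializer(shape=2))
comps, reasonings = rc()
comps.shape       # torch.Size([3, 2])
reasonings.shape  # torch.Size([3, 2, 2])  [num_components, num_classes, 2]
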
prototorch/core/distances.py (new file, 98 lines)
@@ -0,0 +1,98 @@
					"""ProtoTorch distances"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def squared_euclidean_distance(x, y):
 | 
				
			||||||
 | 
					    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
    Compute :math:`{\| \bm x - \bm y \|}_2^2`

    **Alias:**
    ``prototorch.core.distances.sed``
					    """
 | 
				
			||||||
 | 
					    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
 | 
				
			||||||
 | 
					    expanded_x = x.unsqueeze(dim=1)
 | 
				
			||||||
 | 
					    batchwise_difference = y - expanded_x
 | 
				
			||||||
 | 
					    differences_raised = torch.pow(batchwise_difference, 2)
 | 
				
			||||||
 | 
					    distances = torch.sum(differences_raised, axis=2)
 | 
				
			||||||
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def euclidean_distance(x, y):
 | 
				
			||||||
 | 
					    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
    Compute :math:`\sqrt{{\| \bm x - \bm y \|}_2^2}`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    """Alternative Euclidean distance, computed via the batch diagonal of a
    Gram-like matrix (see the comments below)."""
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
    # batch diagonal. See:
    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
    # print(f"{diff.shape=}")  # (nx, ny, ndim)
    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
    # print(f"{distances.shape=}")  # (nx, ny)
    return distances


def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
    Also known as Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``

    :param p: p parameter of the lp norm
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    distances = torch.cdist(x, y, p=p)
    return distances


def omega_distance(x, y, omega):
    r"""Omega distance.

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_2^2`

    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
    r"""Localized Omega distance.

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_2^2`

    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
    batchwise_difference = expanded_y - projected_x
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)
    distances = distances.permute(1, 0)
    return distances


# Aliases
sed = squared_euclidean_distance
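
A quick shape sketch for the distance functions above (illustrative values; `omega` projects both arguments before the squared Euclidean distance is taken):

import torch
from prototorch.core.distances import (lomega_distance, omega_distance,
                                       squared_euclidean_distance)

x = torch.rand(8, 4)  # batch of inputs
y = torch.rand(5, 4)  # e.g. five prototypes
squared_euclidean_distance(x, y).shape  # torch.Size([8, 5])

omega = torch.rand(4, 3)             # one global projection (GMLVQ-style)
omega_distance(x, y, omega).shape    # torch.Size([8, 5])

omegas = torch.rand(5, 4, 3)         # one projection per prototype (local)
lomega_distance(x, y, omegas).shape  # torch.Size([8, 5])
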
prototorch/core/initializers.py (new file, 494 lines)
@@ -0,0 +1,494 @@
					"""ProtoTorch code initializers"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import (
    Callable,
    Type,
    Union,
)

import torch

from ..utils import parse_data_arg, parse_distribution


# Components
class AbstractComponentsInitializer(ABC):
    """Abstract class for all components initializers."""
    ...


class LiteralCompInitializer(AbstractComponentsInitializer):
    """'Generate' the provided components.

    Use this to 'generate' pre-initialized components elsewhere.

    """
    def __init__(self, components):
        self.components = components

    def generate(self, num_components: int = 0):
        """Ignore `num_components` and simply return `self.components`."""
        if not isinstance(self.components, torch.Tensor):
            wmsg = f"Converting components to {torch.Tensor}..."
            warnings.warn(wmsg)
            self.components = torch.Tensor(self.components)
        return self.components


class ShapeAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all dimension-aware components initializers."""
    def __init__(self, shape: Union[Iterable, int]):
        if isinstance(shape, Iterable):
            self.component_shape = tuple(shape)
        else:
            self.component_shape = (shape, )

    @abstractmethod
    def generate(self, num_components: int):
        ...


class ZerosCompInitializer(ShapeAwareCompInitializer):
    """Generate zeros corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.zeros((num_components, ) + self.component_shape)
        return components


class OnesCompInitializer(ShapeAwareCompInitializer):
    """Generate ones corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.ones((num_components, ) + self.component_shape)
        return components


class FillValueCompInitializer(OnesCompInitializer):
    """Generate components with the provided `fill_value`."""
    def __init__(self, shape, fill_value: float = 1.0):
        super().__init__(shape)
        self.fill_value = fill_value

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = ones.fill_(self.fill_value)
        return components


class UniformCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a continuous uniform distribution."""
    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(shape)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * ones.uniform_(self.minimum, self.maximum)
        return components


class RandomNormalCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a standard normal distribution."""
    def __init__(self, shape, shift=0.0, scale=1.0):
        super().__init__(shape)
        self.shift = shift
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * (torch.randn_like(ones) + self.shift)
        return components


class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all data-aware components initializers.

    Components generated by data-aware components initializers inherit the shape
    of the provided data.

    `data` has to be a torch tensor.

    """
    def __init__(self,
                 data: torch.Tensor,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity()):
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, num_components: int):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data


class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
    """'Generate' the components from the provided data."""
    def generate(self, num_components: int = 0):
        """Ignore `num_components` and simply return transformed `self.data`."""
        components = self.generate_end_hook(self.data)
        return components


class SelectionCompInitializer(AbstractDataAwareCompInitializer):
    """Generate components by uniformly sampling from the provided data."""
    def generate(self, num_components: int):
        indices = torch.LongTensor(num_components).random_(0, len(self.data))
        samples = self.data[indices]
        components = self.generate_end_hook(samples)
        return components


class MeanCompInitializer(AbstractDataAwareCompInitializer):
    """Generate components by computing the mean of the provided data."""
    def generate(self, num_components: int):
        mean = self.data.mean(dim=0)
        repeat_dim = [num_components] + [1] * len(mean.shape)
        samples = mean.repeat(repeat_dim)
        components = self.generate_end_hook(samples)
        return components


class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all class-aware components initializers.

    Components generated by class-aware components initializers inherit the shape
    of the provided data.

    `data` could be a torch Dataset or DataLoader or a list/tuple of data and
    target tensors.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity()):
        self.data, self.targets = parse_data_arg(data)
        self.noise = noise
        self.transform = transform
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data
        del self.targets


class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
    """'Generate' components from provided data and requested distribution."""
    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return transformed `self.data`."""
        components = self.generate_end_hook(self.data)
        return components


class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
    """Abstract class for all stratified components initializers."""
    @property
    @abstractmethod
    def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
        ...

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        components = torch.tensor([])
        for k, v in distribution.items():
            stratified_data = self.data[self.targets == k]
            initializer = self.subinit_type(
                stratified_data,
                noise=self.noise,
                transform=self.transform,
            )
            samples = initializer.generate(num_components=v)
            components = torch.cat([components, samples])
        return components


class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
    """Generate components using stratified sampling from the provided data."""
    @property
    def subinit_type(self):
        return SelectionCompInitializer


class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
    """Generate components at stratified means of the provided data."""
    @property
    def subinit_type(self):
        return MeanCompInitializer


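# Illustration (not part of the file): the stratified initializers above
# generate per-class, applying the `subinit_type` to each class slice, e.g.
#
#   X, y = torch.rand(100, 4), torch.randint(0, 2, (100, ))
#   init = StratifiedMeanCompInitializer([X, y])
#   init.generate({0: 2, 1: 2}).shape  # torch.Size([4, 4]), two means per class
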
# Labels
class AbstractLabelsInitializer(ABC):
    """Abstract class for all labels initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...


class LiteralLabelsInitializer(AbstractLabelsInitializer):
    """'Generate' the provided labels.

    Use this to 'generate' pre-initialized labels elsewhere.

    """
    def __init__(self, labels):
        self.labels = labels

    def generate(self, distribution: Union[dict, list, tuple]):
        """Ignore `distribution` and simply return `self.labels`.

        Convert to long tensor, if necessary.
        """
        labels = self.labels
        if not isinstance(labels, torch.LongTensor):
            wmsg = f"Converting labels to {torch.LongTensor}..."
            warnings.warn(wmsg)
            labels = torch.LongTensor(labels)
        return labels


class DataAwareLabelsInitializer(AbstractLabelsInitializer):
    """'Generate' the labels from a torch Dataset."""
    def __init__(self, data):
        self.data, self.targets = parse_data_arg(data)

    def generate(self, distribution: Union[dict, list, tuple]):
					        """Ignore `num_components` and simply return `self.targets`."""
 | 
				
			||||||
 | 
        return self.targets


class LabelsInitializer(AbstractLabelsInitializer):
    """Generate labels from `distribution`."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        labels_list = []
        for k, v in distribution.items():
            labels_list.extend([k] * v)
        labels = torch.LongTensor(labels_list)
        return labels


class OneHotLabelsInitializer(LabelsInitializer):
    """Generate one-hot-encoded labels from `distribution`."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        num_classes = len(distribution.keys())
        # NOTE: this breaks if the class labels are not 0, ..., num_classes - 1
        labels = torch.eye(num_classes)[super().generate(distribution)]
        return labels


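# Illustration (not part of the file): label generation from a distribution:
#
#   LabelsInitializer().generate({0: 1, 1: 2})        # tensor([0, 1, 1])
#   OneHotLabelsInitializer().generate({0: 1, 1: 2})  # tensor([[1., 0.],
#                                                     #         [0., 1.],
#                                                     #         [0., 1.]])
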
# Reasonings
class AbstractReasoningsInitializer(ABC):
    """Abstract class for all reasonings initializers."""
    def __init__(self, components_first: bool = True):
        self.components_first = components_first

    def compute_shape(self, distribution):
        distribution = parse_distribution(distribution)
        num_components = sum(distribution.values())
        num_classes = len(distribution.keys())
        return (num_components, num_classes, 2)

    def generate_end_hook(self, reasonings):
        if not self.components_first:
            reasonings = reasonings.permute(2, 1, 0)
        return reasonings

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...
        return self.generate_end_hook(...)


class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
    """'Generate' the provided reasonings.

    Use this to 'generate' pre-initialized reasonings elsewhere.

    """
    def __init__(self, reasonings, **kwargs):
        super().__init__(**kwargs)
        self.reasonings = reasonings

    def generate(self, distribution: Union[dict, list, tuple]):
					        """Ignore `distributuion` and simply return self.reasonings."""
 | 
				
			||||||
 | 
					        reasonings = self.reasonings
 | 
				
			||||||
 | 
					        if not isinstance(reasonings, torch.Tensor):
 | 
				
			||||||
 | 
					            wmsg = f"Converting reasonings to {torch.Tensor}..."
 | 
				
			||||||
 | 
					            warnings.warn(wmsg)
 | 
				
			||||||
 | 
					            reasonings = torch.Tensor(reasonings)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are all initialized with zeros."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.zeros(*shape)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OnesReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are all initialized with ones."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.ones(*shape)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RandomReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are randomly initialized."""
 | 
				
			||||||
 | 
					    def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					        self.minimum = minimum
 | 
				
			||||||
 | 
					        self.maximum = maximum
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Each component reasons positively for exactly one class."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        num_components, num_classes, _ = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        A = OneHotLabelsInitializer().generate(distribution)
 | 
				
			||||||
 | 
					        B = torch.zeros(num_components, num_classes)
 | 
				
			||||||
 | 
					        reasonings = torch.stack([A, B], dim=-1)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
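A reasonings tensor in this API has shape (num_components, num_classes, 2): one positive and one negative reasoning plane per component. A minimal sketch of what PurePositiveReasoningsInitializer produces for one component per class (a 3-class distribution is assumed purely for illustration):

import torch

# Stack a one-hot plane A (positive reasonings) with a zero plane B
# (negative reasonings), exactly as in generate() above.
A = torch.eye(3)                           # one component per class
B = torch.zeros(3, 3)                      # no negative reasonings
reasonings = torch.stack([A, B], dim=-1)
print(reasonings.shape)                    # torch.Size([3, 3, 2])
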
# Transforms
class AbstractTransformInitializer(ABC):
    """Abstract class for all transform initializers."""
    ...


class AbstractLinearTransformInitializer(AbstractTransformInitializer):
    """Abstract class for all linear transform initializers."""
    def __init__(self, out_dim_first: bool = False):
        self.out_dim_first = out_dim_first

    def generate_end_hook(self, weights):
        if self.out_dim_first:
            weights = weights.permute(1, 0)
        return weights

    @abstractmethod
    def generate(self, in_dim: int, out_dim: int):
        ...
        return self.generate_end_hook(...)


class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with zeros."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.zeros(in_dim, out_dim)
        return self.generate_end_hook(weights)


class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with ones."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.ones(in_dim, out_dim)
        return self.generate_end_hook(weights)


class EyeTransformInitializer(AbstractLinearTransformInitializer):
    """Initialize a matrix with the largest possible identity matrix."""
    def generate(self, in_dim: int, out_dim: int):
        weights = torch.zeros(in_dim, out_dim)
        I = torch.eye(min(in_dim, out_dim))
        weights[:I.shape[0], :I.shape[1]] = I
        return self.generate_end_hook(weights)

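A quick check of EyeTransformInitializer on a rectangular matrix; the import path is an assumption based on the new `prototorch.initializers` namespace announced in the commit message:

import torch
from prototorch.initializers import EyeTransformInitializer

# The top-left min(in_dim, out_dim) block is the identity; the rest is zeros.
weights = EyeTransformInitializer().generate(in_dim=4, out_dim=2)
print(weights.shape)  # torch.Size([4, 2])
# With out_dim_first=True, generate_end_hook permutes this to shape (2, 4).
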
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
    """Abstract class for all data-aware linear transform initializers."""
    def __init__(self,
                 data: torch.Tensor,
                 noise: float = 0.0,
                 transform: Callable = torch.nn.Identity(),
                 out_dim_first: bool = False):
        super().__init__(out_dim_first)  # sets self.out_dim_first, used below
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, weights: torch.Tensor):
        drift = torch.rand_like(weights) * self.noise
        weights = self.transform(weights + drift)
        if self.out_dim_first:
            weights = weights.permute(1, 0)
        return weights


class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
    """Initialize a matrix with eigenvectors of the data."""
    def generate(self, in_dim: int, out_dim: int):
        _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
        return self.generate_end_hook(weights)

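A sketch of the PCA initializer on toy data. torch.pca_lowrank(data, q=out_dim) returns V with shape (in_dim, out_dim), so the result can be used directly as a projection matrix (import path assumed as above):

import torch
from prototorch.initializers import PCALinearTransformInitializer

data = torch.randn(100, 5)                 # 100 samples, 5 features
init = PCALinearTransformInitializer(data)
omega = init.generate(in_dim=5, out_dim=2)
print(omega.shape)                         # torch.Size([5, 2])
projected = data @ omega                   # shape (100, 2)
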
# Aliases - Components
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
MCI = MeanCompInitializer
OCI = OnesCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
UCI = UniformCompInitializer
ZCI = ZerosCompInitializer

# Aliases - Labels
DLI = DataAwareLabelsInitializer
LI = LabelsInitializer
LLI = LiteralLabelsInitializer
OHLI = OneHotLabelsInitializer

# Aliases - Reasonings
LRI = LiteralReasoningsInitializer
ORI = OnesReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer

# Aliases - Transforms
Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer
@@ -1,8 +1,11 @@
-"""ProtoTorch loss functions."""
+"""ProtoTorch losses."""

import torch

+from ..nn.activations import get_activation
+
+
+# Helpers
def _get_matcher(targets, labels):
    """Returns a boolean tensor."""
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
@@ -28,6 +31,7 @@ def _get_dp_dm(distances, targets, plabels, with_indices=False):
    return dp.values, dm.values


+# GLVQ
def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels."""
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
@@ -92,3 +96,76 @@ def rslvq_loss(probabilities, targets, prototype_labels):
    likelihood = correct / whole
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood
+
+
+def margin_loss(y_pred, y_true, margin=0.3):
+    """Compute the margin loss."""
+    dp = torch.sum(y_true * y_pred, dim=-1)
+    dm = torch.max(y_pred - y_true, dim=-1).values
+    return torch.nn.functional.relu(dm - dp + margin)
+
+
+class GLVQLoss(torch.nn.Module):
+    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
+        super().__init__(**kwargs)
+        self.margin = margin
+        self.squashing = get_activation(squashing)
+        self.beta = torch.tensor(beta)
+
+    def forward(self, outputs, targets):
+        distances, plabels = outputs
+        mu = glvq_loss(distances, targets, prototype_labels=plabels)
+        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
+        return torch.sum(batch_loss, dim=0)

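As a sanity check on the sign convention: glvq_loss computes the standard GLVQ classifier score mu = (d+ - d-)/(d+ + d-), so a correctly classified sample yields a negative value. A hand-worked sketch, assuming integer prototype labels:

import torch

distances = torch.tensor([[0.2, 0.8]])  # distances to prototypes 0 and 1
targets = torch.tensor([0])             # the sample belongs to class 0
plabels = torch.tensor([0, 1])          # prototype labels
# d+ = 0.2 (closest correct prototype), d- = 0.8 (closest wrong one)
# mu = (0.2 - 0.8) / (0.2 + 0.8) = -0.6, negative because the sample is correct
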
+class MarginLoss(torch.nn.modules.loss._Loss):
+    def __init__(self,
+                 margin=0.3,
+                 size_average=None,
+                 reduce=None,
+                 reduction="mean"):
+        super().__init__(size_average, reduce, reduction)
+        self.margin = margin
+
+    def forward(self, y_pred, y_true):
+        return margin_loss(y_pred, y_true, self.margin)
+
+
+class NeuralGasEnergy(torch.nn.Module):
+    def __init__(self, lm, **kwargs):
+        super().__init__(**kwargs)
+        self.lm = lm
+
+    def forward(self, d):
+        order = torch.argsort(d, dim=1)
+        ranks = torch.argsort(order, dim=1)
+        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
+
+        return cost, order
+
+    def extra_repr(self):
+        return f"lambda: {self.lm}"
+
+    @staticmethod
+    def _nghood_fn(rankings, lm):
+        return torch.exp(-rankings / lm)
+
+
+class GrowingNeuralGasEnergy(NeuralGasEnergy):
+    def __init__(self, topology_layer, **kwargs):
+        super().__init__(**kwargs)
+        self.topology_layer = topology_layer
+
+    @staticmethod
+    def _nghood_fn(rankings, topology):
+        winner = rankings[:, 0]
+
+        weights = torch.zeros_like(rankings, dtype=torch.float)
+        weights[torch.arange(rankings.shape[0]), winner] = 1.0
+
+        neighbours = topology.get_neighbours(winner)
+
+        weights[neighbours] = 0.1
+
+        return weights
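To see what the rank transform inside NeuralGasEnergy.forward does, here is a one-sample trace with three prototypes (lm = 1.0 assumed):

import torch

d = torch.tensor([[0.5, 0.1, 0.9]])  # distances to 3 prototypes
order = torch.argsort(d, dim=1)      # tensor([[1, 0, 2]]): closest first
ranks = torch.argsort(order, dim=1)  # tensor([[1, 0, 2]]): rank per prototype
cost = torch.sum(torch.exp(-ranks / 1.0) * d)
# = e^-1 * 0.5 + e^0 * 0.1 + e^-2 * 0.9 ≈ 0.41
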
@@ -1,4 +1,4 @@
-"""ProtoTorch pooling functions."""
+"""ProtoTorch pooling."""

from typing import Callable
@@ -78,3 +78,27 @@ def stratified_prod_pooling(values: torch.Tensor,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values
+
+
+class StratifiedSumPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_sum_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_sum_pooling(values, labels)
+
+
+class StratifiedProdPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_prod_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_prod_pooling(values, labels)
+
+
+class StratifiedMinPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_min_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_min_pooling(values, labels)
+
+
+class StratifiedMaxPooling(torch.nn.Module):
+    """Thin wrapper over the `stratified_max_pooling` function."""
+    def forward(self, values, labels):
+        return stratified_max_pooling(values, labels)
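The wrappers above only forward to the functional API. A sketch of the intended stratified-pooling semantics, assuming one-hot component labels of shape (num_components, num_classes): each class pools the values of its own components.

import torch

values = torch.tensor([[0.3, 0.9, 0.5]])  # one sample, 3 components
labels = torch.tensor([[1., 0.],          # component 0 -> class 0
                       [0., 1.],          # component 1 -> class 1
                       [0., 1.]])         # component 2 -> class 1
# Stratified max pooling would then yield [[0.3, 0.9]]:
# class 0 pools over {0.3}, class 1 pools over {0.9, 0.5}.
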
@@ -1,7 +1,19 @@
-"""ProtoTorch similarity functions."""
+"""ProtoTorch similarities."""

import torch

+from .distances import euclidean_distance
+
+
+def gaussian(x, variance=1.0):
+    return torch.exp(-(x * x) / (2 * variance))
+
+
+def euclidean_similarity(x, y, variance=1.0):
+    distances = euclidean_distance(x, y)
+    similarities = gaussian(distances, variance)
+    return similarities
+
+
def cosine_similarity(x, y):
    """Compute the cosine similarity between :math:`x` and :math:`y`.
@@ -9,6 +21,7 @@ def cosine_similarity(x, y):
    Expected dimension of x is 2.
    Expected dimension of y is 2.
    """
+    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
    norm_x = x.pow(2).sum(1).sqrt()
    norm_y = y.pow(2).sum(1).sqrt()
    norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
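A worked check of the new euclidean_similarity: a distance of zero maps to similarity 1, and similarity decays as exp(-d^2 / (2 * variance)).

import torch

d = torch.tensor([0.0, 1.0, 2.0])
sim = torch.exp(-(d * d) / (2 * 1.0))  # gaussian with variance=1.0
print(sim)  # tensor([1.0000, 0.6065, 0.1353])
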
prototorch/core/transforms.py (new file, 43 lines)
@@ -0,0 +1,43 @@
"""ProtoTorch transforms."""

import torch
from torch.nn.parameter import Parameter

from .initializers import (
    AbstractLinearTransformInitializer,
    EyeTransformInitializer,
)


class LinearTransform(torch.nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        initializer:
        AbstractLinearTransformInitializer = EyeTransformInitializer()):
        super().__init__()
        self.set_weights(in_dim, out_dim, initializer)

    @property
    def weights(self):
        return self._weights.detach().cpu()

    def _register_weights(self, weights):
        self.register_parameter("_weights", Parameter(weights))

    def set_weights(
        self,
        in_dim: int,
        out_dim: int,
        initializer:
        AbstractLinearTransformInitializer = EyeTransformInitializer()):
        weights = initializer.generate(in_dim, out_dim)
        self._register_weights(weights)

    def forward(self, x):
        # Multiply by the registered parameter itself, not the detached
        # `weights` property, so gradients can flow; the generated weights
        # already have shape (in_dim, out_dim), so no transpose is needed.
        return x @ self._weights


# Aliases
Omega = LinearTransform
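Usage sketch for the new module (with the forward fix above): an eye-initialized 3 -> 2 transform simply truncates the input dimensions.

import torch
from prototorch.core.transforms import LinearTransform

omega = LinearTransform(in_dim=3, out_dim=2)  # `Omega` is an alias
x = torch.tensor([[1.0, 2.0, 3.0]])
print(omega(x))             # [[1., 2.]]: the identity block keeps dims 1 and 2
print(omega.weights.shape)  # torch.Size([3, 2]), a detached copy for inspection
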
@@ -1,6 +1,12 @@
"""ProtoTorch datasets."""

from .abstract import NumpyDataset
-from .sklearn import Blobs, Circles, Iris, Moons, Random
+from .sklearn import (
+    Blobs,
+    Circles,
+    Iris,
+    Moons,
+    Random,
+)
from .spiral import Spiral
from .tecator import Tecator
@@ -1,10 +1,11 @@
"""ProtoTorch abstract dataset classes.

-Based on `torchvision.VisionDataset` and `torchvision.MNIST`
+Based on `torchvision.VisionDataset` and `torchvision.MNIST`.

For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+
"""

import os
@@ -12,15 +13,6 @@ import os
import torch


-class NumpyDataset(torch.utils.data.TensorDataset):
-    """Create a PyTorch TensorDataset from NumPy arrays."""
-    def __init__(self, data, targets):
-        self.data = torch.Tensor(data)
-        self.targets = torch.LongTensor(targets)
-        tensors = [self.data, self.targets]
-        super().__init__(*tensors)
-
-
class Dataset(torch.utils.data.Dataset):
    """Abstract dataset class to be inherited."""

@@ -44,7 +36,7 @@ class ProtoDataset(Dataset):
    training_file = "training.pt"
    test_file = "test.pt"

-    def __init__(self, root, train=True, download=True, verbose=True):
+    def __init__(self, root="", train=True, download=True, verbose=True):
        super().__init__(root)
        self.train = train  # training set or test set
        self.verbose = verbose
@@ -96,3 +88,12 @@ class ProtoDataset(Dataset):
    def _download(self):
        raise NotImplementedError


+class NumpyDataset(torch.utils.data.TensorDataset):
+    """Create a PyTorch TensorDataset from NumPy arrays."""
+    def __init__(self, data, targets):
+        self.data = torch.Tensor(data)
+        self.targets = torch.LongTensor(targets)
+        tensors = [self.data, self.targets]
+        super().__init__(*tensors)
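NumpyDataset only moves to the bottom of the module; the public import path prototorch.datasets is unchanged, as the __init__ hunk above shows. A usage sketch:

import numpy as np
import torch
from prototorch.datasets import NumpyDataset

data = np.random.rand(10, 4)                # float features
targets = np.random.randint(0, 2, size=10)  # integer class labels
ds = NumpyDataset(data, targets)
x, y = ds[0]
print(x.dtype, y.dtype)  # torch.float32 torch.int64
loader = torch.utils.data.DataLoader(ds, batch_size=4)
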
@@ -1,5 +0,0 @@
"""ProtoTorch functions."""

from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac
from .pooling import *
@@ -1,28 +0,0 @@
"""ProtoTorch competition functions."""

import torch


def wtac(distances: torch.Tensor,
         labels: torch.LongTensor) -> (torch.LongTensor):
    """Winner-Takes-All-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.min(distances, dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
    return winning_labels


def knnc(distances: torch.Tensor,
         labels: torch.LongTensor,
         k: int = 1) -> (torch.LongTensor):
    """K-Nearest-Neighbors-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels
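Per the commit message, these competitions move into the core namespace; the semantics are unchanged. A hand-worked wtac example: each sample takes the label of its closest prototype.

import torch

distances = torch.tensor([[0.3, 0.1, 0.7],
                          [0.9, 0.5, 0.2]])
labels = torch.tensor([0, 1, 1])
winners = torch.min(distances, dim=1).indices  # tensor([1, 2])
print(labels[winners])                         # tensor([1, 1])
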
@@ -1,259 +0,0 @@
"""ProtoTorch distance functions."""

import numpy as np
import torch

from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
                                         equal_int_shape, get_flat)


def squared_euclidean_distance(x, y):
    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.

    Compute :math:`{\| \bm x - \bm y \|}_2^2`

    **Alias:**
    ``prototorch.functions.distances.sed``
    """
    x, y = get_flat(x, y)
    expanded_x = x.unsqueeze(dim=1)
    batchwise_difference = y - expanded_x
    differences_raised = torch.pow(batchwise_difference, 2)
    distances = torch.sum(differences_raised, axis=2)
    return distances


def euclidean_distance(x, y):
    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.

    Compute :math:`{\| \bm x - \bm y \|}_2`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = get_flat(x, y)
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    x, y = get_flat(x, y)
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
    # batch diagonal. See:
    # https://pytorch.org/docs/stable/generated/torch.diagonal.html
    distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
    # print(f"{diff.shape=}")  # (nx, ny, ndim)
    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
    # print(f"{distances.shape=}")  # (nx, ny)
    return distances


def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`,
    also known as the Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``.

    :param p: p parameter of the lp norm
    """
    x, y = get_flat(x, y)
    distances = torch.cdist(x, y, p=p)
    return distances


def omega_distance(x, y, omega):
    r"""Omega distance.

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
    r"""Localized Omega distance.

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
    batchwise_difference = expanded_y - projected_x
    differences_squared = batchwise_difference**2
    distances = torch.sum(differences_squared, dim=2)
    distances = distances.permute(1, 0)
    return distances


def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
    r"""Compute a Euclidean distance matrix between two sets of vectors.

    The last dimension must be the vector dimension. The distance is computed
    via the dot-product identity, which avoids the memory overhead of an
    explicit subtraction.

    - ``x.shape = (number_of_x_vectors, vector_dim)``
    - ``y.shape = (number_of_y_vectors, vector_dim)``

    output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
    """
    for tensor in [x, y]:
        if tensor.ndim != 2:
            raise ValueError(
                "The tensor dimension must be two. You provide: tensor.ndim=" +
                str(tensor.ndim) + ".")
    if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
        raise ValueError(
            "The vector shape must be equivalent in both tensors. You provide: tuple(x.shape)[1]="
            + str(tuple(x.shape)[1]) + " and tuple(y.shape)[1]=" +
            str(tuple(y.shape)[1]) + ".")

    y = torch.transpose(y)

    diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
            torch.sum(y**2, axis=0, keepdims=True))

    if not squared:
        if epsilon == 0:
            diss = torch.sqrt(diss)
        else:
            diss = torch.sqrt(torch.max(diss, epsilon))

    return diss


def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
    r"""Tangent distance, based on the TensorFlow implementation by Sascha Saralajew.

    For more info about tangent distances, see
    DOI:10.1109/IJCNN.2016.7727534.

    The subspaces are always assumed to be transposed and must be orthogonal.
    For local, non-sparse signals, subspaces must be provided.

    - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
    - shape(protos): proto_number x dim1 x dim2 x ... x dimN
    - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)

    The subspace should be orthogonalized.
    PyTorch implementation of Sascha Saralajew's TensorFlow code;
    translation by Christoph Raab.
    """
    signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
    proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
    subspace_int_shape = tuple(subspaces.shape)

    # check if the shapes are correct
    _check_shapes(signal_int_shape, proto_int_shape)

    atom_axes = list(range(3, len(signal_int_shape)))
    # for sparse signals, we use the memory efficient implementation
    if signal_int_shape[1] == 1:
        signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])

        if len(atom_axes) > 1:
            protos = torch.reshape(protos, [proto_shape[0], -1])

        if subspaces.ndim == 2:
            # clean solution without map if the matrix_scope is global
            projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
                subspaces, torch.transpose(subspaces))

            projected_signals = torch.dot(signals, projectors)
            projected_protos = torch.dot(protos, projectors)

            diss = euclidean_distance_matrix(projected_signals,
                                             projected_protos,
                                             squared=squared,
                                             epsilon=epsilon)

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

        else:

            # no solution without map possible --> memory efficient but slow!
            projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
                subspaces,
                subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])

            projected_protos = (protos @ subspaces
                                ).T  # K.batch_dot(projectors, protos, [1, 1]))

            def projected_norm(projector):
                return torch.sum(torch.dot(signals, projector)**2, axis=1)

            diss = (torch.transpose(map(projected_norm, projectors)) -
                    2 * torch.dot(signals, projected_protos) +
                    torch.sum(projected_protos**2, axis=0, keepdims=True))

            if not squared:
                if epsilon == 0:
                    diss = torch.sqrt(diss)
                else:
                    diss = torch.sqrt(torch.max(diss, epsilon))

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

    else:
        signals = signals.permute([0, 2, 1] + atom_axes)

        diff = signals - protos

        # global tangent space
        if subspaces.ndim == 2:
            # Scope: Projectors
            projectors = subspaces

            # Scope: Tangentspace Projections
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            projected_diff = diff @ projectors
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[0], signal_shape[2], signal_shape[1]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([0, 2, 1])

        # local tangent spaces
        else:
            # Scope: Calculate Projectors
            projectors = subspaces

            # Scope: Tangentspace Projections
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            diff = diff.permute([1, 0, 2])
            projected_diff = torch.bmm(diff, projectors)
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[1], signal_shape[0], signal_shape[2]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([1, 0, 2]).squeeze(-1)


# Aliases
sed = squared_euclidean_distance
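This file is removed in the refactor; the new similarities module above imports from a core .distances module that replaces it. For orientation, omega_distance measures squared Euclidean distance in the Omega-projected space. A tiny worked example with the identity projection:

import torch

x = torch.tensor([[1.0, 0.0]])
y = torch.tensor([[0.0, 1.0]])
omega = torch.eye(2)                  # identity projection for illustration
px, py = x @ omega, y @ omega
d = torch.sum((py - px.unsqueeze(1))**2, dim=2)
print(d)  # tensor([[2.]])
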
@@ -1,94 +0,0 @@
import torch


def get_flat(*args):
    rv = [x.view(x.size(0), -1) for x in args]
    return rv


def calculate_prototype_accuracy(y_pred, y_true, plabels):
    """Compute the accuracy of a prototype-based model
    via the Winner-Takes-All rule.

    Requirements:
    y_pred.shape == y_true.shape
    unique(y_pred) in plabels
    """
    with torch.no_grad():
        idx = torch.argmin(y_pred, axis=1)
        return torch.true_divide(torch.sum(y_true == plabels[idx]),
                                 len(y_pred)) * 100


def predict_label(y_pred, plabels):
    r"""Predict labels given a prediction of a prototype-based model."""
    with torch.no_grad():
        return plabels[torch.argmin(y_pred, 1)]


def mixed_shape(inputs):
    if not torch.is_tensor(inputs):
        raise ValueError("Input must be a tensor.")
    else:
        int_shape = list(inputs.shape)
        # sometimes int_shape returns mixed integer types
        int_shape = [int(i) if i is not None else i for i in int_shape]
        tensor_shape = inputs.shape

        for i, s in enumerate(int_shape):
            if s is None:
                int_shape[i] = tensor_shape[i]
        return tuple(int_shape)


def equal_int_shape(shape_1, shape_2):
    if not isinstance(shape_1,
                      (tuple, list)) or not isinstance(shape_2, (tuple, list)):
        raise ValueError("Input shapes must be list or tuple.")
    for shape in [shape_1, shape_2]:
        if not all([isinstance(x, int) or x is None for x in shape]):
            raise ValueError(
                "Input shapes must be list or tuple of int and None values.")

    if len(shape_1) != len(shape_2):
        return False
    else:
        for axis, value in enumerate(shape_1):
            if value is not None and shape_2[axis] not in {value, None}:
                return False
        return True


def _check_shapes(signal_int_shape, proto_int_shape):
    if len(signal_int_shape) < 4:
        raise ValueError(
            "The number of signal dimensions must be >=4. You provide: " +
            str(len(signal_int_shape)))

    if len(proto_int_shape) < 2:
        raise ValueError(
            "The number of proto dimensions must be >=2. You provide: " +
            str(len(proto_int_shape)))

    if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
        raise ValueError(
            "The atom shape of signals must equal that of protos. You provide: signals.shape[3:]="
            + str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
            str(proto_int_shape[1:]))

    # not a sparse signal
    if signal_int_shape[1] != 1:
        if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
            raise ValueError(
                "If the signal is not sparse, the number of prototypes must be equal in signals and "
                "protos. You provide: " + str(signal_int_shape[1]) + " != " +
                str(proto_int_shape[0]))

    return True


def _int_and_mixed_shape(tensor):
    shape = mixed_shape(tensor)
    int_shape = tuple(i if isinstance(i, int) else None for i in shape)

    return shape, int_shape
@@ -1,107 +0,0 @@
"""ProtoTorch initialization functions."""

from itertools import chain

import torch

INITIALIZERS = dict()


def register_initializer(function):
    """Add the initializer to the registry."""
    INITIALIZERS[function.__name__] = function
    return function


def labels_from(distribution, one_hot=True):
    """Takes a distribution tensor and returns a labels tensor."""
    num_classes = distribution.shape[0]
    llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
    # labels = [l for cl in llist for l in cl]  # flatten the list of lists
    flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
    plabels = torch.tensor(flat_llist, requires_grad=False)
    if one_hot:
        return torch.eye(num_classes)[plabels]
    return plabels


@register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.ones(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.zeros(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.rand(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    protos = torch.randn(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        mean_xl = torch.mean(xl, dim=0)
        protos[i] = mean_xl
    plabels = labels_from(prototype_distribution, one_hot=one_hot)
    return protos, plabels


@register_initializer
def stratified_random(x_train,
                      y_train,
                      prototype_distribution,
                      one_hot=True,
                      epsilon=1e-7):
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
        random_xl = xl[rand_index]
        protos[i] = random_xl + epsilon
    plabels = labels_from(prototype_distribution, one_hot=one_hot)
    return protos, plabels


def get_initializer(funcname):
    """Deserialize the initializer."""
    if callable(funcname):
        return funcname
    if funcname in INITIALIZERS:
        return INITIALIZERS.get(funcname)
    raise NameError(f"Initializer {funcname} was not found.")
@@ -1,35 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import torch


def orthogonalization(tensors):
    r"""Orthogonalization of a given tensor via polar decomposition."""
    u, _, v = torch.svd(tensors, compute_uv=True)
    u_shape = tuple(list(u.shape))
    v_shape = tuple(list(v.shape))

    # reshape to (num x N x M)
    u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
    v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))

    out = u @ v.permute([0, 2, 1])

    out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))

    return out


def trace_normalization(tensors):
    r"""Trace normalization."""
    epsilon = torch.tensor([1e-10], dtype=torch.float64)
    # Scope: trace_normalization
    constant = torch.trace(tensors)

    if epsilon != 0:
        constant = torch.max(constant, epsilon)

    return tensors / constant
@@ -1,32 +0,0 @@
import torch


# Functions
def gaussian(distances, variance):
    return torch.exp(-(distances * distances) / (2 * variance))


def rank_scaled_gaussian(distances, lambd):
    order = torch.argsort(distances, dim=1)
    ranks = torch.argsort(order, dim=1)

    return torch.exp(-torch.exp(-ranks / lambd) * distances)


# Modules
class GaussianPrior(torch.nn.Module):
    def __init__(self, variance):
        super().__init__()
        self.variance = variance

    def forward(self, distances):
        return gaussian(distances, self.variance)


class RankScaledGaussianPrior(torch.nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, distances):
        return rank_scaled_gaussian(distances, self.lambd)
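A quick numeric check of the Gaussian similarity above (restated standalone, not an import from the removed module):

import torch


def gaussian(distances, variance):
    # similarity decays with the squared distance
    return torch.exp(-(distances * distances) / (2 * variance))


d = torch.tensor([[0.0, 1.0, 2.0]])
print(gaussian(d, variance=1.0))  # tensor([[1.0000, 0.6065, 0.1353]])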
@@ -1,5 +0,0 @@
"""ProtoTorch modules."""

from .competitions import *
from .pooling import *
from .wrappers import LambdaLayer, LossLayer
@@ -1,42 +0,0 @@
"""ProtoTorch Competition Modules."""

import torch

from prototorch.functions.competitions import knnc, wtac


class WTAC(torch.nn.Module):
    """Winner-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.
    """
    def forward(self, distances, labels):
        return wtac(distances, labels)


class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function, applied to negated probabilities.
    """
    def forward(self, probs, labels):
        return wtac(-1.0 * probs, labels)


class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.
    """
    def __init__(self, k=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        return f"k: {self.k}"
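A winner-takes-all competition returns, per sample, the label of the closest prototype. A standalone sketch of the semantics behind `wtac` (an illustration, not the library implementation), assuming `distances` has shape (batch, num_protos):

import torch


def wtac_sketch(distances, labels):
    # index of the closest prototype per sample, then its label
    winners = torch.argmin(distances, dim=1)
    return labels[winners]


d = torch.tensor([[0.1, 2.0], [3.0, 0.5]])
plabels = torch.tensor([0, 1])
print(wtac_sketch(d, plabels))  # tensor([0, 1])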
@@ -1,59 +0,0 @@
"""ProtoTorch losses."""

import torch

from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss


class GLVQLoss(torch.nn.Module):
    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = torch.tensor(beta)

    def forward(self, outputs, targets):
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, prototype_labels=plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)


class NeuralGasEnergy(torch.nn.Module):
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    def forward(self, d):
        order = torch.argsort(d, dim=1)
        ranks = torch.argsort(order, dim=1)
        cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)

        return cost, order

    def extra_repr(self):
        return f"lambda: {self.lm}"

    @staticmethod
    def _nghood_fn(rankings, lm):
        return torch.exp(-rankings / lm)


class GrowingNeuralGasEnergy(NeuralGasEnergy):
    def __init__(self, topology_layer, **kwargs):
        super().__init__(**kwargs)
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
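The `glvq_loss` used above computes the classic GLVQ classifier cost per sample, mu = (d_plus - d_minus) / (d_plus + d_minus), where d_plus is the distance to the closest prototype with the correct label and d_minus the closest with any wrong label. A standalone sketch of that formula (an illustration, not the library code):

import torch


def glvq_mu(distances, targets, plabels):
    # mu < 0 means the closest correct prototype beats the closest wrong one
    matching = targets.unsqueeze(1) == plabels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_plus = torch.where(matching, distances, inf).min(dim=1).values
    d_minus = torch.where(~matching, distances, inf).min(dim=1).values
    return (d_plus - d_minus) / (d_plus + d_minus)


d = torch.tensor([[0.2, 0.9], [0.7, 0.1]])
print(glvq_mu(d, torch.tensor([0, 0]), torch.tensor([0, 1])))
# tensor([-0.6364,  0.7500]); the negative entry is a correct classification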
@@ -1,170 +0,0 @@
import torch
from torch import nn

from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization


class GTLVQ(nn.Module):
    r"""Generalized Tangent Learning Vector Quantization

    Parameters
    ----------
    num_classes: int
        Number of classes of the given classification problem.

    subspace_data: torch.tensor of shape (n_batch, feature_dim, feature_dim)
        Subspace data for the point approximation; required.

    prototype_data: torch.tensor of shape (n_init_data, feature_dim) (optional)
        Prototype data for initialization of the prototypes used in GTLVQ.

    subspace_size: int (default=256, optional)
        Subspace dimension of the projectors. Currently only supported
        with tangent_projection_type="global".

    tangent_projection_type: string
        Specifies the tangent projection type.
        Options:    local
                    local_proj
                    global
        local: computes the tangent distances without emphasizing projected
        data; only distances are available.
        local_proj: computes tangent distances and returns the projected data
        for further use. Be careful: data is repeated by the number of
        prototypes.
        global: the number of subspaces is set to one, and every prototype
        uses the same one.

    prototypes_per_class: int (default=2, optional)
        Number of prototypes per class.

    feature_dim: int (default=256)
        Dimensionality of the feature space specified as an integer.
        Prototype dimension.

    Notes
    -----
    GTLVQ [1] is a prototype-based classification learning model. It
    uses tangent distances for a local point approximation of an assumed
    data manifold via prototypical representations.

    GTLVQ requires subspace projectors for transforming the data
    and prototypes into the affine subspace. Every prototype is
    equipped with a specific subspace and represents a point
    approximation of the assumed manifold.

    In practice, prototypes and data are projected onto this manifold,
    and the pairwise Euclidean distance is computed.

    References
    ----------
    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
    in classification based on manifold models and its relation
    to tangent metric learning. In: 2017 International Joint
    Conference on Neural Networks (IJCNN).
    Vol. 2017-May: IEEE, 2017, pp. 1756–1765
    """
    def __init__(
        self,
        num_classes,
        subspace_data=None,
        prototype_data=None,
        subspace_size=256,
        tangent_projection_type="local",
        prototypes_per_class=2,
        feature_dim=256,
    ):
        super(GTLVQ, self).__init__()

        self.num_protos = num_classes * prototypes_per_class
        self.num_protos_class = prototypes_per_class
        self.subspace_size = feature_dim if subspace_size is None else subspace_size
        self.feature_dim = feature_dim
        self.num_classes = num_classes

        cls_initializer = StratifiedMeanInitializer(prototype_data)
        cls_distribution = {
            "num_classes": num_classes,
            "prototypes_per_class": prototypes_per_class,
        }

        self.cls = LabeledComponents(cls_distribution, cls_initializer)

        if subspace_data is None:
            raise ValueError("Init Data must be specified!")

        self.tpt = tangent_projection_type
        with torch.no_grad():
            if self.tpt == "local":
                self.init_local_subspace(subspace_data, subspace_size,
                                         self.num_protos)
            elif self.tpt == "global":
                self.init_global_subspace(subspace_data, subspace_size)
            else:
                self.subspaces = None

    def forward(self, x):
        if self.tpt == "local":
            dis = self.local_tangent_distances(x)
        elif self.tpt == "global":
            dis = self.global_tangent_distances(x)
        else:
            dis = (x @ self.cls.prototypes.T) / (
                torch.norm(x, dim=1, keepdim=True) @ torch.norm(
                    self.cls.prototypes, dim=1, keepdim=True).T)
        return dis

    def init_global_subspace(self, data, num_subspaces):
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def init_local_subspace(self, data, num_subspaces, num_protos):
        data = data - torch.mean(data, dim=0)
        _, _, v = torch.svd(data, some=False)
        v = v[:, :num_subspaces]
        subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def global_tangent_distances(self, x):
        # Tangent projection
        x, projected_prototypes = (
            x @ self.subspaces,
            self.cls.prototypes @ self.subspaces,
        )
        # Euclidean distance
        return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_distances(self, x):
        # Tangent distance
        x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
                                  x.size(-1))
        protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
                                                   self.cls.num_components,
                                                   x.size(-1))
        projectors = torch.eye(
            self.subspaces.shape[-2], device=x.device) - torch.bmm(
                self.subspaces, self.subspaces.permute([0, 2, 1]))
        diff = (x - protos)
        diff = diff.permute([1, 0, 2])
        diff = torch.bmm(diff, projectors)
        diff = torch.norm(diff, 2, dim=-1).T
        return diff

    def get_parameters(self):
        return {
            "params": self.cls.components,
        }, {
            "params": self.subspaces
        }

    def orthogonalize_subspace(self):
        if self.subspaces is not None:
            with torch.no_grad():
                ortho_subspaces = (orthogonalization(self.subspaces)
                                   if self.tpt == "global" else
                                   torch.nn.init.orthogonal_(self.subspaces))
                self.subspaces.copy_(ortho_subspaces)
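The core operation in `local_tangent_distances` is projecting the difference vector onto the orthogonal complement of each prototype's tangent subspace before taking the norm. A standalone single-prototype sketch of that idea (assuming `u` is an orthonormal basis of shape (feature_dim, subspace_size)):

import torch


def tangent_distance(x, w, u):
    # distance of (x - w) to the affine subspace spanned by u:
    # project onto the orthogonal complement (I - u @ u.T), then norm
    projector = torch.eye(u.shape[0]) - u @ u.T
    return torch.norm((x - w) @ projector, dim=-1)


u, _ = torch.linalg.qr(torch.randn(5, 2))  # orthonormal basis, shape (5, 2)
x = torch.randn(3, 5)                      # three samples in R^5
w = torch.zeros(5)                         # one prototype
print(tangent_distance(x, w, u))           # one tangent distance per sample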
@@ -1,32 +0,0 @@
"""ProtoTorch Pooling Modules."""

import torch

from prototorch.functions.pooling import (stratified_max_pooling,
                                          stratified_min_pooling,
                                          stratified_prod_pooling,
                                          stratified_sum_pooling)


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
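A standalone sketch of what stratified pooling means here (an assumption about the semantics, not the library implementation): reduce per-prototype values down to one value per class label, e.g. the per-class minimum.

import torch


def stratified_min_sketch(values, labels):
    # one pooled value per class: reduce over the prototypes
    # that carry that class label
    clabels = torch.unique(labels)
    return torch.stack(
        [values[:, labels == c].min(dim=1).values for c in clabels], dim=1)


values = torch.tensor([[0.3, 0.1, 0.9, 0.5]])  # (batch, num_protos)
labels = torch.tensor([0, 0, 1, 1])            # prototype labels
print(stratified_min_sketch(values, labels))   # tensor([[0.1000, 0.5000]])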
							
								
								
									
prototorch/nn/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
"""ProtoTorch Neural Network Module"""

from .activations import *
from .wrappers import *
@@ -1,4 +1,4 @@
-"""ProtoTorch activation functions."""
+"""ProtoTorch activations"""

 import torch

@@ -57,6 +57,10 @@ def get_activation(funcname):
     """Deserialize the activation function."""
     if callable(funcname):
         return funcname
-    if funcname in ACTIVATIONS:
+    elif funcname in ACTIVATIONS:
         return ACTIVATIONS.get(funcname)
-    raise NameError(f"Activation {funcname} was not found.")
+    else:
+        emsg = f"Unable to find matching function for `{funcname}` " \
+            f"in `prototorch.nn.activations`. "
+        helpmsg = f"Possible values are {list(ACTIVATIONS.keys())}."
+        raise NameError(emsg + helpmsg)
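The new branch surfaces the registry contents in the error message. A quick illustration of the behaviour, using a hypothetical toy registry in place of the real `ACTIVATIONS`:

ACTIVATIONS = {"identity": lambda x, **kwargs: x}  # hypothetical registry


def get_activation(funcname):
    """Deserialize the activation function."""
    if callable(funcname):
        return funcname
    elif funcname in ACTIVATIONS:
        return ACTIVATIONS.get(funcname)
    else:
        emsg = f"Unable to find matching function for `{funcname}` " \
            f"in `prototorch.nn.activations`. "
        helpmsg = f"Possible values are {list(ACTIVATIONS.keys())}."
        raise NameError(emsg + helpmsg)


get_activation("swish")
# NameError: Unable to find matching function for `swish` in
# `prototorch.nn.activations`. Possible values are ['identity'].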
@@ -1,4 +1,4 @@
-"""ProtoTorch Wrappers."""
+"""ProtoTorch wrappers."""

 import torch

@@ -0,0 +1,8 @@
"""ProtoTorch utils module"""

from .colors import hex_to_rgb, rgb_to_hex
from .utils import (
    mesh2d,
    parse_data_arg,
    parse_distribution,
)
@@ -1,46 +0,0 @@
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""

from collections import defaultdict
from typing import Dict, List

from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
from matplotlib.figure import Figure

__version__ = "0.2.0"


class Camera:
    """Make animations easier."""
    def __init__(self, figure: Figure) -> None:
        """Create camera from matplotlib figure."""
        self._figure = figure
        # need to keep track of artists for each axis
        self._offsets: Dict[str, Dict[int, int]] = {
            k: defaultdict(int)
            for k in
            ["collections", "patches", "lines", "texts", "artists", "images"]
        }
        self._photos: List[List[Artist]] = []

    def snap(self) -> List[Artist]:
        """Capture current state of the figure."""
        frame_artists: List[Artist] = []
        for i, axis in enumerate(self._figure.axes):
            if axis.legend_ is not None:
                axis.add_artist(axis.legend_)
            for name in self._offsets:
                new_artists = getattr(axis, name)[self._offsets[name][i]:]
                frame_artists += new_artists
                self._offsets[name][i] += len(new_artists)
        self._photos.append(frame_artists)
        return frame_artists

    def animate(self, *args, **kwargs) -> ArtistAnimation:
        """Animate the snapshots taken.

        Uses matplotlib.animation.ArtistAnimation

        Returns
        -------
        ArtistAnimation
        """
        return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
@@ -1,78 +1,15 @@
-"""ProtoFlow color utilities."""
+"""ProtoTorch color utilities"""

-import matplotlib.lines as mlines
-from matplotlib import cm
-from matplotlib.colors import Normalize, to_hex, to_rgb

-
-def color_scheme(n,
-                 cmap="viridis",
-                 form="hex",
-                 tikz=False,
-                 zero_indexed=False):
-    """Return *n* colors from the color scheme.
-
-    Arguments:
-        n (int): number of colors to return
-
-    Keyword Arguments:
-        cmap (str): Name of a matplotlib `colormap\
-            <https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html>`_.
-        form (str): Colorformat (supports "hex" and "rgb").
-        tikz (bool): Output as `TikZ <https://github.com/pgf-tikz/pgf>`_
-            command.
-        zero_indexed (bool): Use zero indexing for output array.
-
-    Returns:
-        (list): List of colors
-    """
-    cmap = cm.get_cmap(cmap)
-    colornorm = Normalize(vmin=1, vmax=n)
-    hex_map = dict()
-    rgb_map = dict()
-    for cl in range(1, n + 1):
-        if zero_indexed:
-            hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
-        else:
-            hex_map[cl] = to_hex(cmap(colornorm(cl)))
-            rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
-    if tikz:
-        for k, v in rgb_map.items():
-            print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
-    if form == "hex":
-        return hex_map
-    elif form == "rgb":
-        return rgb_map
-    else:
-        return hex_map
+def hex_to_rgb(hex_values):
+    for v in hex_values:
+        v = v.lstrip('#')
+        lv = len(v)
+        c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
+        yield c


-def get_legend_handles(labels, marker="dots", zero_indexed=False):
-    """Return matplotlib legend handles and colors."""
-    handles = list()
-    n = len(labels)
-    colors = color_scheme(n,
-                          cmap="viridis",
-                          form="hex",
-                          zero_indexed=zero_indexed)
-    for label, color in zip(labels, colors.values()):
-        if marker == "dots":
-            handle = mlines.Line2D(
-                [],
-                [],
-                color="white",
-                markerfacecolor=color,
-                marker="o",
-                markersize=10,
-                markeredgecolor="k",
-                label=label,
-            )
-        else:
-            handle = mlines.Line2D([], [],
-                                   color=color,
-                                   marker="",
-                                   markersize=15,
-                                   label=label)
-            handles.append(handle)
-    return handles, colors
+def rgb_to_hex(rgb_values):
+    for v in rgb_values:
+        c = "%02x%02x%02x" % tuple(v)
+        yield c
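The two new generators are inverses of each other on 24-bit colors; a quick usage check (the names are exported by the new `prototorch.utils` package above):

from prototorch.utils import hex_to_rgb, rgb_to_hex

print(list(hex_to_rgb(["#ff0080"])))      # [[255, 0, 128]]
print(list(rgb_to_hex([[255, 0, 128]])))  # ['ff0080']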
							
								
								
									
prototorch/utils/utils.py (new file, +104)
@@ -0,0 +1,104 @@
"""ProtoTorch utilities"""

import warnings
from collections.abc import Iterable
from typing import Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
    if x is not None:
        x_shift = border * np.ptp(x[:, 0])
        y_shift = border * np.ptp(x[:, 1])
        x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
        y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
    else:
        x_min, x_max = -border, border
        y_min, y_max = -border, border
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))
    mesh = np.c_[xx.ravel(), yy.ravel()]
    return mesh, xx, yy


def distribution_from_list(list_dist: list[int],
                           clabels: Iterable[int] = None):
    clabels = clabels or list(range(len(list_dist)))
    distribution = dict(zip(clabels, list_dist))
    return distribution


def parse_distribution(user_distribution,
                       clabels: Iterable[int] = None) -> dict[int, int]:
    """Parse user-provided distribution.

    Return a dictionary with integer keys that represent the class labels and
    values that denote the number of components/prototypes with that class
    label.

    The argument `user_distribution` could be any one of a number of allowed
    formats. If it is a Python list, it is assumed that there are as many
    entries in this list as there are classes, and the value at each index of
    this list describes the number of prototypes for that particular class.
    So, [1, 1, 1] implies that we have three classes with one prototype per
    class. If it is a Python tuple, the shorthand
    (num_classes, prototypes_per_class) is assumed. If it is a Python
    dictionary, the key-value pairs describe the class label and the number
    of prototypes for that class respectively. So, {0: 2, 1: 2, 2: 2} implies
    that we have three classes with labels {0, 1, 2}, each equipped with two
    prototypes. If, however, the dictionary contains the keys "num_classes"
    and "per_class", they are parsed to use their values as one might expect.

    """
    if isinstance(user_distribution, dict):
        if "num_classes" in user_distribution.keys():
            num_classes = int(user_distribution["num_classes"])
            per_class = int(user_distribution["per_class"])
            return distribution_from_list([per_class] * num_classes, clabels)
        else:
            return user_distribution
    elif isinstance(user_distribution, tuple):
        assert len(user_distribution) == 2
        num_classes, per_class = user_distribution
        num_classes, per_class = int(num_classes), int(per_class)
        return distribution_from_list([per_class] * num_classes, clabels)
    elif isinstance(user_distribution, list):
        return distribution_from_list(user_distribution, clabels)
    else:
        msg = f"`distribution` was not understood. " \
            f"You have provided: {user_distribution}."
        raise ValueError(msg)


def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
    """Return data and targets as torch tensors."""
    if isinstance(data_arg, Dataset):
        if hasattr(data_arg, "__len__"):
            ds_size = len(data_arg)  # type: ignore
            loader = DataLoader(data_arg, batch_size=ds_size)
            data, targets = next(iter(loader))
        else:
            emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
            raise TypeError(emsg)

    elif isinstance(data_arg, DataLoader):
        data = torch.tensor([])
        targets = torch.tensor([])
        for x, y in data_arg:
            data = torch.cat([data, x])
            targets = torch.cat([targets, y])
    else:
        assert len(data_arg) == 2
        data, targets = data_arg
        if not isinstance(data, torch.Tensor):
            wmsg = f"Converting data to {torch.Tensor}..."
            warnings.warn(wmsg)
            data = torch.Tensor(data)
        if not isinstance(targets, torch.LongTensor):
            wmsg = f"Converting targets to {torch.LongTensor}..."
            warnings.warn(wmsg)
            targets = torch.LongTensor(targets)
    return data, targets
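For reference, all three accepted formats normalize to the same dictionary shape; the new test suite added further below exercises exactly these cases:

from prototorch.utils import parse_distribution

parse_distribution([1, 1, 0, 2])                        # {0: 1, 1: 1, 2: 0, 3: 2}
parse_distribution((2, 3))                              # {0: 3, 1: 3}
parse_distribution(dict(num_classes=3, per_class=2))    # {0: 2, 1: 2, 2: 2}
parse_distribution([1, 1, 0, 2], clabels=[1, 2, 5, 3])  # {1: 1, 2: 1, 5: 0, 3: 2}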
							
								
								
									
setup.cfg (new file, +15)
@@ -0,0 +1,15 @@
[pylint]
disable =
    too-many-arguments,
    too-few-public-methods,
    fixme,

[pycodestyle]
max-line-length = 79

[isort]
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79
							
								
								
									
setup.py (3 lines changed)
@@ -61,8 +61,9 @@ setup(
     license="MIT",
     install_requires=INSTALL_REQUIRES,
     extras_require={
-        "docs": DOCS,
         "datasets": DATASETS,
+        "dev": DEV,
+        "docs": DOCS,
         "examples": EXAMPLES,
         "tests": TESTS,
         "all": ALL,
@@ -1,26 +0,0 @@
"""ProtoTorch components test suite."""

import torch

import prototorch as pt


def test_labcomps_zeros_init():
    protos = torch.zeros(3, 2)
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=pt.components.Zeros(2),
    )
    assert (c.components == protos).any() == True


def test_labcomps_warmstart():
    protos = torch.randn(3, 2)
    plabels = torch.tensor([1, 2, 3])
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=None,
        initialized_components=[protos, plabels],
    )
    assert (c.components == protos).any() == True
    assert (c.component_labels == plabels).any() == True
							
								
								
									
										760
									
								
								tests/test_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										760
									
								
								tests/test_core.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,760 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch core test suite"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					from prototorch.utils import parse_distribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Utils
 | 
				
			||||||
 | 
					def test_parse_distribution_dict_0():
 | 
				
			||||||
 | 
					    distribution = {"num_classes": 1, "per_class": 0}
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					    assert distribution == {0: 0}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_parse_distribution_dict_1():
 | 
				
			||||||
 | 
					    distribution = dict(num_classes=3, per_class=2)
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					    assert distribution == {0: 2, 1: 2, 2: 2}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_parse_distribution_dict_2():
 | 
				
			||||||
 | 
					    distribution = {0: 1, 2: 2, -1: 3}
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					    assert distribution == {0: 1, 2: 2, -1: 3}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_parse_distribution_tuple():
 | 
				
			||||||
 | 
					    distribution = (2, 3)
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					    assert distribution == {0: 3, 1: 3}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_parse_distribution_list():
 | 
				
			||||||
 | 
					    distribution = [1, 1, 0, 2]
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					    assert distribution == {0: 1, 1: 1, 2: 0, 3: 2}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_parse_distribution_custom_labels():
 | 
				
			||||||
 | 
					    distribution = [1, 1, 0, 2]
 | 
				
			||||||
 | 
					    clabels = [1, 2, 5, 3]
 | 
				
			||||||
 | 
					    distribution = parse_distribution(distribution, clabels)
 | 
				
			||||||
 | 
					    assert distribution == {1: 1, 2: 1, 5: 0, 3: 2}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Components initializers
 | 
				
			||||||
 | 
					def test_literal_comp_generate():
 | 
				
			||||||
 | 
					    protos = torch.rand(4, 3, 5, 5)
 | 
				
			||||||
 | 
					    c = pt.initializers.LiteralCompInitializer(protos)
 | 
				
			||||||
 | 
					    components = c.generate([])
 | 
				
			||||||
 | 
					    assert torch.allclose(components, protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_literal_comp_generate_from_list():
 | 
				
			||||||
 | 
					    protos = [[0, 1], [2, 3], [4, 5]]
 | 
				
			||||||
 | 
					    c = pt.initializers.LiteralCompInitializer(protos)
 | 
				
			||||||
 | 
					    with pytest.warns(UserWarning):
 | 
				
			||||||
 | 
					        components = c.generate([])
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.Tensor(protos))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_shape_aware_raises_error():
 | 
				
			||||||
 | 
					    with pytest.raises(TypeError):
 | 
				
			||||||
 | 
					        _ = pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_data_aware_comp_generate():
 | 
				
			||||||
 | 
					    protos = torch.rand(4, 3, 5, 5)
 | 
				
			||||||
 | 
					    c = pt.initializers.DataAwareCompInitializer(protos)
 | 
				
			||||||
 | 
					    components = c.generate(num_components="IgnoreMe!")
 | 
				
			||||||
 | 
					    assert torch.allclose(components, protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_class_aware_comp_generate():
 | 
				
			||||||
 | 
					    protos = torch.rand(4, 2, 3, 5, 5)
 | 
				
			||||||
 | 
					    plabels = torch.tensor([0, 0, 1, 1]).long()
 | 
				
			||||||
 | 
					    c = pt.initializers.ClassAwareCompInitializer([protos, plabels])
 | 
				
			||||||
 | 
					    components = c.generate(distribution=[])
 | 
				
			||||||
 | 
					    assert torch.allclose(components, protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_zeros_comp_generate():
 | 
				
			||||||
 | 
					    shape = (3, 5, 5)
 | 
				
			||||||
 | 
					    c = pt.initializers.ZerosCompInitializer(shape)
 | 
				
			||||||
 | 
					    components = c.generate(num_components=4)
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.zeros(4, 3, 5, 5))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_ones_comp_generate():
 | 
				
			||||||
 | 
					    c = pt.initializers.OnesCompInitializer(2)
 | 
				
			||||||
 | 
					    components = c.generate(num_components=3)
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.ones(3, 2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_fill_value_comp_generate():
 | 
				
			||||||
 | 
					    c = pt.initializers.FillValueCompInitializer(2, 0.0)
 | 
				
			||||||
 | 
					    components = c.generate(num_components=3)
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.zeros(3, 2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_uniform_comp_generate_min_max_bound():
 | 
				
			||||||
 | 
					    c = pt.initializers.UniformCompInitializer(2, -1.0, 1.0)
 | 
				
			||||||
 | 
					    components = c.generate(num_components=1024)
 | 
				
			||||||
 | 
					    assert components.min() >= -1.0
 | 
				
			||||||
 | 
					    assert components.max() <= 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_random_comp_generate_mean():
 | 
				
			||||||
 | 
					    c = pt.initializers.RandomNormalCompInitializer(2, -1.0)
 | 
				
			||||||
 | 
					    components = c.generate(num_components=1024)
 | 
				
			||||||
 | 
					    assert torch.allclose(components.mean(),
 | 
				
			||||||
 | 
					                          torch.tensor(-1.0),
 | 
				
			||||||
 | 
					                          rtol=1e-05,
 | 
				
			||||||
 | 
					                          atol=1e-01)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_comp_generate_0_components():
 | 
				
			||||||
 | 
					    c = pt.initializers.ZerosCompInitializer(2)
 | 
				
			||||||
 | 
					    _ = c.generate(num_components=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_stratified_mean_comp_generate():
 | 
				
			||||||
 | 
					    # yapf: disable
 | 
				
			||||||
 | 
					    x = torch.Tensor(
 | 
				
			||||||
 | 
					        [[0,  -1, -2],
 | 
				
			||||||
 | 
					         [10, 11, 12],
 | 
				
			||||||
 | 
					         [0,   0,  0],
 | 
				
			||||||
 | 
					         [2,   2,  2]])
 | 
				
			||||||
 | 
					    y = torch.LongTensor([0, 0, 1, 1])
 | 
				
			||||||
 | 
					    desired = torch.Tensor(
 | 
				
			||||||
 | 
					        [[5.0, 5.0, 5.0],
 | 
				
			||||||
 | 
					         [1.0, 1.0, 1.0]])
 | 
				
			||||||
 | 
					    # yapf: enable
 | 
				
			||||||
 | 
					    c = pt.initializers.StratifiedMeanCompInitializer(data=[x, y])
 | 
				
			||||||
 | 
					    actual = c.generate([1, 1])
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_stratified_selection_comp_generate():
 | 
				
			||||||
 | 
					    # yapf: disable
 | 
				
			||||||
 | 
					    x = torch.Tensor(
 | 
				
			||||||
 | 
					        [[0, 0, 0],
 | 
				
			||||||
 | 
					         [1, 1, 1],
 | 
				
			||||||
 | 
					         [0, 0, 0],
 | 
				
			||||||
 | 
					         [1, 1, 1]])
 | 
				
			||||||
 | 
					    y = torch.LongTensor([0, 1, 0, 1])
 | 
				
			||||||
 | 
					    desired = torch.Tensor(
 | 
				
			||||||
 | 
					        [[0, 0, 0],
 | 
				
			||||||
 | 
					         [1, 1, 1]])
 | 
				
			||||||
 | 
					    # yapf: enable
 | 
				
			||||||
 | 
					    c = pt.initializers.StratifiedSelectionCompInitializer(data=[x, y])
 | 
				
			||||||
 | 
					    actual = c.generate([1, 1])
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Labels initializers
 | 
				
			||||||
 | 
					def test_literal_labels_init():
 | 
				
			||||||
 | 
					    l = pt.initializers.LiteralLabelsInitializer([0, 0, 1, 2])
 | 
				
			||||||
 | 
					    with pytest.warns(UserWarning):
 | 
				
			||||||
 | 
					        labels = l.generate([])
 | 
				
			||||||
 | 
					    assert torch.allclose(labels, torch.LongTensor([0, 0, 1, 2]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_labels_init_from_list():
 | 
				
			||||||
 | 
					    l = pt.initializers.LabelsInitializer()
 | 
				
			||||||
 | 
					    components = l.generate(distribution=[1, 1, 1])
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.LongTensor([0, 1, 2]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_labels_init_from_tuple_legal():
 | 
				
			||||||
 | 
					    l = pt.initializers.LabelsInitializer()
 | 
				
			||||||
 | 
					    components = l.generate(distribution=(3, 1))
 | 
				
			||||||
 | 
					    assert torch.allclose(components, torch.LongTensor([0, 1, 2]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_labels_init_from_tuple_illegal():
 | 
				
			||||||
 | 
					    l = pt.initializers.LabelsInitializer()
 | 
				
			||||||
 | 
					    with pytest.raises(AssertionError):
 | 
				
			||||||
 | 
					        _ = l.generate(distribution=(1, 1, 1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_data_aware_labels_init():
 | 
				
			||||||
 | 
					    data, targets = [0, 1, 2, 3], [0, 0, 1, 1]
 | 
				
			||||||
 | 
					    ds = pt.datasets.NumpyDataset(data, targets)
 | 
				
			||||||
 | 
					    l = pt.initializers.DataAwareLabelsInitializer(ds)
 | 
				
			||||||
 | 
					    labels = l.generate([])
 | 
				
			||||||
 | 
					    assert torch.allclose(labels, torch.LongTensor(targets))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Reasonings initializers
 | 
				
			||||||
 | 
					def test_literal_reasonings_init():
 | 
				
			||||||
 | 
					    r = pt.initializers.LiteralReasoningsInitializer([0, 0, 1, 2])
 | 
				
			||||||
 | 
					    with pytest.warns(UserWarning):
 | 
				
			||||||
 | 
					        reasonings = r.generate([])
 | 
				
			||||||
 | 
					    assert torch.allclose(reasonings, torch.Tensor([0, 0, 1, 2]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_random_reasonings_init():
 | 
				
			||||||
 | 
					    r = pt.initializers.RandomReasoningsInitializer(0.2, 0.8)
 | 
				
			||||||
 | 
					    reasonings = r.generate(distribution=[0, 1])
 | 
				
			||||||
 | 
					    assert torch.numel(reasonings) == 1 * 2 * 2
 | 
				
			||||||
 | 
					    assert reasonings.min() >= 0.2
 | 
				
			||||||
 | 
					    assert reasonings.max() <= 0.8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_zeros_reasonings_init():
 | 
				
			||||||
 | 
					    r = pt.initializers.ZerosReasoningsInitializer()
 | 
				
			||||||
 | 
					    reasonings = r.generate(distribution=[0, 1])
 | 
				
			||||||
 | 
					    assert torch.allclose(reasonings, torch.zeros(1, 2, 2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_ones_reasonings_init():
 | 
				
			||||||
 | 
					    r = pt.initializers.ZerosReasoningsInitializer()
 | 
				
			||||||
 | 
					    reasonings = r.generate(distribution=[1, 2, 3])
 | 
				
			||||||
 | 
					    assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
def test_pure_positive_reasonings_init_one_per_class():
    r = pt.initializers.PurePositiveReasoningsInitializer(
        components_first=False)
    reasonings = r.generate(distribution=(4, 1))
    assert torch.allclose(reasonings[0], torch.eye(4))


def test_pure_positive_reasonings_init_unrepresented_classes():
    r = pt.initializers.PurePositiveReasoningsInitializer()
    reasonings = r.generate(distribution=[9, 0, 0, 0])
    assert reasonings.shape[0] == 9
    assert reasonings.shape[1] == 4
    assert reasonings.shape[2] == 2


def test_random_reasonings_init_channels_not_first():
    r = pt.initializers.RandomReasoningsInitializer(components_first=False)
    reasonings = r.generate(distribution=[0, 0, 0, 1])
    assert reasonings.shape[0] == 2
    assert reasonings.shape[1] == 4
    assert reasonings.shape[2] == 1


# Transform initializers
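# EyeTransformInitializer yields the identity for square shapes and a
# zero-padded (narrow) or truncated (wide) identity otherwise, as the
# three cases below spell out.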
def test_eye_transform_init_square():
    t = pt.initializers.EyeTransformInitializer()
    I = t.generate(3, 3)
    assert torch.allclose(I, torch.eye(3))


def test_eye_transform_init_narrow():
    t = pt.initializers.EyeTransformInitializer()
    actual = t.generate(3, 2)
    desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
    assert torch.allclose(actual, desired)


def test_eye_transform_init_wide():
    t = pt.initializers.EyeTransformInitializer()
    actual = t.generate(2, 3)
    desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
    assert torch.allclose(actual, desired)


# Transforms
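# LinearTransform keeps its weights as an (in_dim, out_dim) matrix by
# default and, absent an explicit initializer, falls back to the eye-style
# initialization checked first.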
def test_linear_transform():
    l = pt.transforms.LinearTransform(2, 4)
    actual = l.weights
    desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
    assert torch.allclose(actual, desired)


def test_linear_transform_zeros_init():
    l = pt.transforms.LinearTransform(
        in_dim=2,
        out_dim=4,
        initializer=pt.initializers.ZerosLinearTransformInitializer(),
    )
    actual = l.weights
    desired = torch.zeros(2, 4)
    assert torch.allclose(actual, desired)


def test_linear_transform_out_dim_first():
    l = pt.transforms.LinearTransform(
        in_dim=2,
        out_dim=4,
        initializer=pt.initializers.OLTI(out_dim_first=True),
    )
    assert l.weights.shape[0] == 4
    assert l.weights.shape[1] == 2


# Components
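# Components requires both the number of components and an initializer;
# omitting either, or passing None, raises a TypeError.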
def test_components_no_initializer():
    with pytest.raises(TypeError):
        _ = pt.components.Components(3, None)


def test_components_no_num_components():
    with pytest.raises(TypeError):
        _ = pt.components.Components(initializer=pt.initializers.OCI(2))


def test_components_none_num_components():
    with pytest.raises(TypeError):
        _ = pt.components.Components(None, initializer=pt.initializers.OCI(2))


def test_components_no_args():
    with pytest.raises(TypeError):
        _ = pt.components.Components()


def test_components_zeros_init():
    c = pt.components.Components(3, pt.initializers.ZCI(2))
    assert torch.allclose(c.components, torch.zeros(3, 2))


def test_labeled_components_dict_init():
    c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2))
    assert torch.allclose(c.components, torch.ones(3, 2))
    assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))


def test_labeled_components_list_init():
    c = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
    assert torch.allclose(c.components, torch.ones(3, 2))
    assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))


def test_labeled_components_tuple_init():
    c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2))
    assert torch.allclose(c.components, torch.ones(3, 2))
    assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1]))


# Labels
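# Labels accepts the same distribution formats as LabeledComponents:
# a dict of {class: count} or a list of per-class counts.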
def test_standalone_labels_dict_init():
    l = pt.components.Labels({0: 3})
    assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))


def test_standalone_labels_list_init():
    l = pt.components.Labels([3])
    assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))


def test_standalone_labels_tuple_init():
    l = pt.components.Labels({0: 1, 1: 2})
    assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1]))


# Losses
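# The GLVQ loss per sample is (d+ - d-) / (d+ + d-), where d+ is the
# distance to the closest prototype of the correct class and d- the
# distance to the closest prototype of any other class. In all three
# tests below, d+ = 0 and d- = 1 for each of the 100 samples, so every
# per-sample loss is -1 and the batch sum is -100.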
def test_glvq_loss_int_labels():
    d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
    labels = torch.tensor([0, 1])
    targets = torch.ones(100)
    batch_loss = pt.losses.glvq_loss(distances=d,
                                     target_labels=targets,
                                     prototype_labels=labels)
    loss_value = torch.sum(batch_loss, dim=0)
    assert loss_value == -100


def test_glvq_loss_one_hot_labels():
    d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
    labels = torch.tensor([[0, 1], [1, 0]])
    wl = torch.tensor([1, 0])
    targets = torch.stack([wl for _ in range(100)], dim=0)
    batch_loss = pt.losses.glvq_loss(distances=d,
                                     target_labels=targets,
                                     prototype_labels=labels)
    loss_value = torch.sum(batch_loss, dim=0)
    assert loss_value == -100


def test_glvq_loss_one_hot_unequal():
    dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
    d = torch.stack(dlist, dim=1)
    labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
    wl = torch.tensor([1, 0])
    targets = torch.stack([wl for _ in range(100)], dim=0)
    batch_loss = pt.losses.glvq_loss(distances=d,
                                     target_labels=targets,
                                     prototype_labels=labels)
    loss_value = torch.sum(batch_loss, dim=0)
    assert loss_value == -100


# Activations
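# Activation functions live in the pt.nn.ACTIVATIONS registry;
# get_activation resolves either a registered name or a plain callable and
# raises a NameError for unknown names.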
class TestActivations(unittest.TestCase):
    def setUp(self):
        self.flist = ["identity", "sigmoid_beta", "swish_beta"]
        self.x = torch.randn(1024, 1)

    def test_registry(self):
        self.assertIsNotNone(pt.nn.ACTIVATIONS)

    def test_funcname_deserialization(self):
        for funcname in self.flist:
            f = pt.nn.get_activation(funcname)
            iscallable = callable(f)
            self.assertTrue(iscallable)

    def test_callable_deserialization(self):
        def dummy(x, **kwargs):
            return x

        for f in [dummy, lambda x: x]:
            f = pt.nn.get_activation(f)
            iscallable = callable(f)
            self.assertTrue(iscallable)
            self.assertEqual(1, f(1))

    def test_unknown_deserialization(self):
        for funcname in ["blubb", "foobar"]:
            with self.assertRaises(NameError):
                _ = pt.nn.get_activation(funcname)

    def test_identity(self):
        actual = pt.nn.identity(self.x)
        desired = self.x
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_sigmoid_beta1(self):
        actual = pt.nn.sigmoid_beta(self.x, beta=1.0)
        desired = torch.sigmoid(self.x)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_swish_beta1(self):
        actual = pt.nn.swish_beta(self.x, beta=1.0)
        desired = self.x * torch.sigmoid(self.x)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def tearDown(self):
        del self.x


# Competitions
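# WTAC (winner-takes-all competition) returns the label of the closest
# prototype per sample; KNNC generalizes this to a k-nearest-neighbor
# vote (only exercised with k=1 here).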
class TestCompetitions(unittest.TestCase):
    def setUp(self):
        pass

    def test_wtac(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([2, 0])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_wtac_unequal_dist(self):
        d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
        labels = torch.tensor([0, 1, 1])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([0, 1])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_wtac_one_hot(self):
        d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
        labels = torch.tensor([[0, 1], [1, 0]])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([[0, 1], [1, 0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_knnc_k1(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        competition_layer = pt.competitions.KNNC(k=1)
        actual = competition_layer(d, labels)
        desired = torch.tensor([2, 0])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def tearDown(self):
        pass


# Pooling
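# Stratified pooling reduces the prototype axis group-wise: all columns of
# d that share a label are collapsed into one column per class via min,
# max, sum, or product.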
class TestPooling(unittest.TestCase):
    def setUp(self):
        pass

    def test_stratified_min(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_min_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_min_trivial(self):
        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
        labels = torch.tensor([0, 1, 2])
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_max(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        pooling_layer = pt.pooling.StratifiedMaxPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_max_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 2, 1, 0])
        labels = torch.nn.functional.one_hot(labels, num_classes=3)
        pooling_layer = pt.pooling.StratifiedMaxPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.LongTensor([0, 0, 1, 2])
        pooling_layer = pt.pooling.StratifiedSumPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        pooling_layer = pt.pooling.StratifiedSumPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_prod(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        pooling_layer = pt.pooling.StratifiedProdPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def tearDown(self):
        pass


# Distances
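# Each distance test validates the batched implementation against a naive
# double loop over torch.nn.functional.pairwise_distance.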
class TestDistances(unittest.TestCase):
    def setUp(self):
        self.nx, self.mx = 32, 2048
        self.ny, self.my = 8, 2048
        self.x = torch.randn(self.nx, self.mx)
        self.y = torch.randn(self.ny, self.my)

    def test_manhattan(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=1)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=1,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)

    def test_euclidean(self):
        actual = pt.distances.euclidean_distance(self.x, self.y)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=3)
        self.assertIsNone(mismatch)

    def test_squared_euclidean(self):
        actual = pt.distances.squared_euclidean_distance(self.x, self.y)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)

    def test_lpnorm_p0(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=0)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=0,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_p2(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=2)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_p3(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=3)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=3,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_pinf(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=float("inf"))
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=float("inf"),
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)
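    # With an identity omega matrix, omega_distance (and the localized
    # lomega_distance below) reduces to the squared Euclidean distance.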
    def test_omega_identity(self):
        omega = torch.eye(self.mx, self.my)
        actual = pt.distances.omega_distance(self.x, self.y, omega=omega)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)

    def test_lomega_identity(self):
        omega = torch.eye(self.mx, self.my)
        omegas = torch.stack([omega for _ in range(self.ny)], dim=0)
        actual = pt.distances.lomega_distance(self.x, self.y, omegas=omegas)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)

    def tearDown(self):
        del self.x, self.y
@@ -1,32 +1,97 @@
"""ProtoTorch datasets test suite"""

import os
import shutil
import unittest

import numpy as np
import torch

import prototorch as pt
from prototorch.datasets.abstract import Dataset, ProtoDataset


class TestAbstract(unittest.TestCase):
    def setUp(self):
        self.ds = Dataset("./artifacts")

    def test_getitem(self):
        with self.assertRaises(NotImplementedError):
            _ = self.ds[0]

    def test_len(self):
        with self.assertRaises(NotImplementedError):
            _ = len(self.ds)

    def tearDown(self):
        del self.ds


class TestProtoDataset(unittest.TestCase):
    def test_download(self):
        with self.assertRaises(NotImplementedError):
            _ = ProtoDataset("./artifacts", download=True)

    def test_exists(self):
        with self.assertRaises(RuntimeError):
            _ = ProtoDataset("./artifacts", download=False)


class TestNumpyDataset(unittest.TestCase):
    def test_list_init(self):
        ds = pt.datasets.NumpyDataset([1], [1])
        self.assertEqual(len(ds), 1)

    def test_numpy_init(self):
        data = np.random.randn(3, 2)
        targets = np.array([0, 1, 2])
        ds = pt.datasets.NumpyDataset(data, targets)
        self.assertEqual(len(ds), 3)


class TestSpiral(unittest.TestCase):
    def test_init(self):
        ds = pt.datasets.Spiral(num_samples=10)
        self.assertEqual(len(ds), 10)


class TestIris(unittest.TestCase):
    def setUp(self):
        self.ds = pt.datasets.Iris()

    def test_size(self):
        self.assertEqual(len(self.ds), 150)

    def test_dims(self):
        self.assertEqual(self.ds.data.shape[1], 4)

    def test_dims_selection(self):
        ds = pt.datasets.Iris(dims=[0, 1])
        self.assertEqual(ds.data.shape[1], 2)


class TestBlobs(unittest.TestCase):
    def test_size(self):
        ds = pt.datasets.Blobs(num_samples=10)
        self.assertEqual(len(ds), 10)


class TestRandom(unittest.TestCase):
    def test_size(self):
        ds = pt.datasets.Random(num_samples=10)
        self.assertEqual(len(ds), 10)


class TestCircles(unittest.TestCase):
    def test_size(self):
        ds = pt.datasets.Circles(num_samples=10)
        self.assertEqual(len(ds), 10)


class TestMoons(unittest.TestCase):
    def test_size(self):
        ds = pt.datasets.Moons(num_samples=10)
        self.assertEqual(len(ds), 10)


class TestTecator(unittest.TestCase):
@@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        self._remove_artifacts()
        with self.assertRaises(RuntimeError):
            _ = pt.datasets.Tecator(rootdir, download=False)

    def test_download_caching(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
        _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)

    def test_repr(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
        self.assertTrue("Split: Train" in train.__repr__())

    def test_download_train(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        train = pt.datasets.Tecator(root=rootdir,
                                    train=True,
                                    download=True,
                                    verbose=False)
        train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
        x_train, y_train = train.data, train.targets
        self.assertEqual(x_train.shape[0], 144)
        self.assertEqual(y_train.shape[0], 144)
@@ -68,7 +133,7 @@ class TestTecator(unittest.TestCase):

    def test_download_test(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
        x_test, y_test = test.data, test.targets
        self.assertEqual(x_test.shape[0], 71)
        self.assertEqual(y_test.shape[0], 71)
@@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase):

    def test_class_to_idx(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
        _ = test.class_to_idx

    def test_getitem(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
        x, y = test[0]
        self.assertEqual(x.shape[0], 100)
        self.assertIsInstance(y, int)

    def test_loadable_with_dataloader(self):
        rootdir = self.artifacts_dir.rpartition("/")[0]
        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
        _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)

    def tearDown(self):
        self._remove_artifacts()
@@ -1,581 +0,0 @@
 | 
				
			|||||||
"""ProtoTorch functions test suite."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import unittest
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from prototorch.functions import (activations, competitions, distances,
 | 
					 | 
				
			||||||
                                  initializers, losses, pooling)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestActivations(unittest.TestCase):
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        self.flist = ["identity", "sigmoid_beta", "swish_beta"]
 | 
					 | 
				
			||||||
        self.x = torch.randn(1024, 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_registry(self):
 | 
					 | 
				
			||||||
        self.assertIsNotNone(activations.ACTIVATIONS)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_funcname_deserialization(self):
 | 
					 | 
				
			||||||
        for funcname in self.flist:
 | 
					 | 
				
			||||||
            f = activations.get_activation(funcname)
 | 
					 | 
				
			||||||
            iscallable = callable(f)
 | 
					 | 
				
			||||||
            self.assertTrue(iscallable)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # def test_torch_script(self):
 | 
					 | 
				
			||||||
    #     for funcname in self.flist:
 | 
					 | 
				
			||||||
    #         f = activations.get_activation(funcname)
 | 
					 | 
				
			||||||
    #         self.assertIsInstance(f, torch.jit.ScriptFunction)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_callable_deserialization(self):
 | 
					 | 
				
			||||||
        def dummy(x, **kwargs):
 | 
					 | 
				
			||||||
            return x
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for f in [dummy, lambda x: x]:
 | 
					 | 
				
			||||||
            f = activations.get_activation(f)
 | 
					 | 
				
			||||||
            iscallable = callable(f)
 | 
					 | 
				
			||||||
            self.assertTrue(iscallable)
 | 
					 | 
				
			||||||
            self.assertEqual(1, f(1))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_unknown_deserialization(self):
 | 
					 | 
				
			||||||
        for funcname in ["blubb", "foobar"]:
 | 
					 | 
				
			||||||
            with self.assertRaises(NameError):
 | 
					 | 
				
			||||||
                _ = activations.get_activation(funcname)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_identity(self):
 | 
					 | 
				
			||||||
        actual = activations.identity(self.x)
 | 
					 | 
				
			||||||
        desired = self.x
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_sigmoid_beta1(self):
 | 
					 | 
				
			||||||
        actual = activations.sigmoid_beta(self.x, beta=1.0)
 | 
					 | 
				
			||||||
        desired = torch.sigmoid(self.x)
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_swish_beta1(self):
 | 
					 | 
				
			||||||
        actual = activations.swish_beta(self.x, beta=1.0)
 | 
					 | 
				
			||||||
        desired = self.x * torch.sigmoid(self.x)
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def tearDown(self):
 | 
					 | 
				
			||||||
        del self.x
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestCompetitions(unittest.TestCase):
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_wtac(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 1, 2, 3])
 | 
					 | 
				
			||||||
        actual = competitions.wtac(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([2, 0])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_wtac_unequal_dist(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 1, 1])
 | 
					 | 
				
			||||||
        actual = competitions.wtac(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([0, 1])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_wtac_one_hot(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([[0, 1], [1, 0]])
 | 
					 | 
				
			||||||
        actual = competitions.wtac(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([[0, 1], [1, 0]])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_knnc_k1(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 1, 2, 3])
 | 
					 | 
				
			||||||
        actual = competitions.knnc(d, labels, k=1)
 | 
					 | 
				
			||||||
        desired = torch.tensor([2, 0])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def tearDown(self):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestPooling(unittest.TestCase):
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_stratified_min(self):
 | 
					 | 
				
			||||||
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
 | 
					 | 
				
			||||||
        labels = torch.tensor([0, 0, 1, 2])
 | 
					 | 
				
			||||||
        actual = pooling.stratified_min_pooling(d, labels)
 | 
					 | 
				
			||||||
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
 | 
					 | 
				
			||||||
        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
					 | 
				
			||||||
                                                        desired,
 | 
					 | 
				
			||||||
                                                        decimal=5)
 | 
					 | 
				
			||||||
        self.assertIsNone(mismatch)

    def test_stratified_min_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        actual = pooling.stratified_min_pooling(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_min_trivial(self):
        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
        labels = torch.tensor([0, 1, 2])
        actual = pooling.stratified_min_pooling(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_max(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        actual = pooling.stratified_max_pooling(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_max_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 2, 1, 0])
        labels = torch.nn.functional.one_hot(labels, num_classes=3)
        actual = pooling.stratified_max_pooling(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.LongTensor([0, 0, 1, 2])
        actual = pooling.stratified_sum_pooling(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        actual = pooling.stratified_sum_pooling(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_prod(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        actual = pooling.stratified_prod_pooling(d, labels)
        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)
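
    # Worked check for the prod case: the label-0 entries of the second row
    # are 9.0, 8.0 and 7.0, and 9 * 8 * 7 = 504; the output columns
    # apparently follow the sorted unique labels [0, 2, 3].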

    def tearDown(self):
        pass


class TestDistances(unittest.TestCase):
    def setUp(self):
        self.nx, self.mx = 32, 2048
        self.ny, self.my = 8, 2048
        self.x = torch.randn(self.nx, self.mx)
        self.y = torch.randn(self.ny, self.my)

    def test_manhattan(self):
        actual = distances.lpnorm_distance(self.x, self.y, p=1)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=1,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)
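
    # The reference values come from an O(nx * ny) Python loop. A vectorized
    # equivalent (an aside, not the API under test) would be
    # torch.cdist(self.x, self.y, p=1), which computes the same pairwise
    # Manhattan distances in one call.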

    def test_euclidean(self):
        actual = distances.euclidean_distance(self.x, self.y)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=3)
        self.assertIsNone(mismatch)

    def test_squared_euclidean(self):
        actual = distances.squared_euclidean_distance(self.x, self.y)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)
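
    # The looser tolerance here (decimal=2 instead of 3) is presumably
    # because squaring roughly doubles the relative float32 error
    # accumulated over the 2048 dimensions.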

    def test_lpnorm_p0(self):
        actual = distances.lpnorm_distance(self.x, self.y, p=0)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=0,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_p2(self):
        actual = distances.lpnorm_distance(self.x, self.y, p=2)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_p3(self):
        actual = distances.lpnorm_distance(self.x, self.y, p=3)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=3,
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_lpnorm_pinf(self):
        actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=float("inf"),
                    keepdim=False,
                )
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=4)
        self.assertIsNone(mismatch)

    def test_omega_identity(self):
        omega = torch.eye(self.mx, self.my)
        actual = distances.omega_distance(self.x, self.y, omega=omega)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)

    def test_lomega_identity(self):
        omega = torch.eye(self.mx, self.my)
        omegas = torch.stack([omega for _ in range(self.ny)], dim=0)
        actual = distances.lomega_distance(self.x, self.y, omegas=omegas)
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                desired[i][j] = (torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=2,
                    keepdim=False,
                )**2)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=2)
        self.assertIsNone(mismatch)
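
    # Both identity tests rely on the (local) omega distance reducing to the
    # squared Euclidean distance when omega is the identity, i.e. roughly
    # d(x, y) = ||omega @ (x - y)||**2 (up to transposition conventions),
    # so omega = I leaves x - y unchanged.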

    def tearDown(self):
        del self.x, self.y


class TestInitializers(unittest.TestCase):
    def setUp(self):
        self.flist = [
            "zeros",
            "ones",
            "rand",
            "randn",
            "stratified_mean",
            "stratified_random",
        ]
        self.x = torch.tensor(
            [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
            dtype=torch.float32)
        self.y = torch.tensor([0, 0, 1, 1])
        self.gen = torch.manual_seed(42)

    def test_registry(self):
        self.assertIsNotNone(initializers.INITIALIZERS)

    def test_funcname_deserialization(self):
        for funcname in self.flist:
            f = initializers.get_initializer(funcname)
            iscallable = callable(f)
            self.assertTrue(iscallable)

    def test_callable_deserialization(self):
        def dummy(x):
            return x

        for f in [dummy, lambda x: x]:
            f = initializers.get_initializer(f)
            iscallable = callable(f)
            self.assertTrue(iscallable)
            self.assertEqual(1, f(1))

    def test_unknown_deserialization(self):
        for funcname in ["blubb", "foobar"]:
            with self.assertRaises(NameError):
                _ = initializers.get_initializer(funcname)
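
    # Taken together, these tests pin down the deserialization contract: a
    # registered name or a callable is accepted, anything else raises
    # NameError. A sketch of the implied lookup (not the actual source):
    #
    #   if callable(key): return key
    #   if key in INITIALIZERS: return INITIALIZERS[key]
    #   raise NameError(key)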

    def test_zeros(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.zeros(self.x, self.y, pdist)
        desired = torch.zeros(2, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_ones(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.ones(self.x, self.y, pdist)
        desired = torch.ones(2, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_rand(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.rand(self.x, self.y, pdist)
        desired = torch.rand(2, 3, generator=torch.manual_seed(42))
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_randn(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.randn(self.x, self.y, pdist)
        desired = torch.randn(2, 3, generator=torch.manual_seed(42))
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_mean_equal1(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
        desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)
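
    # Sanity arithmetic: class 0 averages [0, -1, -2] and [10, 11, 12] to
    # [5, 5, 5]; class 1 averages [0, 0, 0] and [2, 2, 2] to [1, 1, 1];
    # pdist = [1, 1] requests one prototype per class.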

    def test_stratified_random_equal1(self):
        pdist = torch.tensor([1, 1])
        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
                                                   False)
        desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_mean_equal2(self):
        pdist = torch.tensor([2, 2])
        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
        desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
                                [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_random_equal2(self):
        pdist = torch.tensor([2, 2])
        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
                                                   False)
        desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
                                [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_mean_unequal(self):
        pdist = torch.tensor([1, 3])
        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
        desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
                                [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_random_unequal(self):
        pdist = torch.tensor([1, 3])
        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
                                                   False)
        desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_mean_unequal_one_hot(self):
        pdist = torch.tensor([1, 3])
        y = torch.eye(2)[self.y]
        desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
                                 [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
        mismatch = np.testing.assert_array_almost_equal(actual1,
                                                        desired1,
                                                        decimal=5)
        mismatch = np.testing.assert_array_almost_equal(actual2,
                                                        desired2,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_random_unequal_one_hot(self):
        pdist = torch.tensor([1, 3])
        y = torch.eye(2)[self.y]
        actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
        desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
        mismatch = np.testing.assert_array_almost_equal(actual1,
                                                        desired1,
                                                        decimal=5)
        mismatch = np.testing.assert_array_almost_equal(actual2,
                                                        desired2,
                                                        decimal=5)
        self.assertIsNone(mismatch)
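
    # In the one-hot variants the initializers also return prototype labels
    # (actual2) as one-hot rows matching pdist = [1, 3]: one prototype for
    # class 0 followed by three for class 1.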

    def tearDown(self):
        del self.x, self.y, self.gen
        _ = torch.seed()


class TestLosses(unittest.TestCase):
    def setUp(self):
        pass

    def test_glvq_loss_int_labels(self):
        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
        labels = torch.tensor([0, 1])
        targets = torch.ones(100)
        batch_loss = losses.glvq_loss(distances=d,
                                      target_labels=targets,
                                      prototype_labels=labels)
        loss_value = torch.sum(batch_loss, dim=0)
        self.assertEqual(loss_value, -100)

    def test_glvq_loss_one_hot_labels(self):
        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
        labels = torch.tensor([[0, 1], [1, 0]])
        wl = torch.tensor([1, 0])
        targets = torch.stack([wl for _ in range(100)], dim=0)
        batch_loss = losses.glvq_loss(distances=d,
                                      target_labels=targets,
                                      prototype_labels=labels)
        loss_value = torch.sum(batch_loss, dim=0)
        self.assertEqual(loss_value, -100)

    def test_glvq_loss_one_hot_unequal(self):
        dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
        d = torch.stack(dlist, dim=1)
        labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
        wl = torch.tensor([1, 0])
        targets = torch.stack([wl for _ in range(100)], dim=0)
        batch_loss = losses.glvq_loss(distances=d,
                                      target_labels=targets,
                                      prototype_labels=labels)
        loss_value = torch.sum(batch_loss, dim=0)
        self.assertEqual(loss_value, -100)
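
    # All three cases expect -100: with the usual GLVQ cost
    # mu = (d_correct - d_incorrect) / (d_correct + d_incorrect), every
    # sample here has d_correct = 0 and d_incorrect = 1, so mu = -1 per
    # sample, summed over the 100 samples.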

    def tearDown(self):
        pass

47  tests/test_utils.py  Normal file
@@ -0,0 +1,47 @@
"""ProtoTorch utils test suite"""

import numpy as np
import torch

import prototorch as pt


def test_mesh2d_without_input():
    mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10)
    assert mesh.shape[0] == 100
    assert mesh.shape[1] == 2
    assert xx.shape[0] == 10
    assert xx.shape[1] == 10
    assert yy.shape[0] == 10
    assert yy.shape[1] == 10
    assert np.min(xx) == -2.0
    assert np.max(xx) == 2.0
    assert np.min(yy) == -2.0
    assert np.max(yy) == 2.0


def test_mesh2d_with_torch_input():
    x = 10 * torch.rand(5, 2)
    mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100)
    assert mesh.shape[0] == 100 * 100
    assert mesh.shape[1] == 2
    assert xx.shape[0] == 100
    assert xx.shape[1] == 100
    assert yy.shape[0] == 100
    assert yy.shape[1] == 100
    assert np.min(xx) == x[:, 0].min()
    assert np.max(xx) == x[:, 0].max()
    assert np.min(yy) == x[:, 1].min()
    assert np.max(yy) == x[:, 1].max()
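
# The two tests above suggest mesh2d flattens a meshgrid over the data range
# (or over [-border, border] when no input is given). A hedged sketch of the
# no-input case:
#
#   lin = np.linspace(-border, border, resolution)
#   xx, yy = np.meshgrid(lin, lin)
#   mesh = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (resolution**2, 2)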


def test_hex_to_rgb():
    red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0]
    assert red_rgb[0] == 255
    assert red_rgb[1] == 0
    assert red_rgb[2] == 0


def test_rgb_to_hex():
    blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0]
    assert blue_hex.lower() == "0000ff"