Compare commits
106 Commits
v0.3.0-dev
...
feature/tr
Author | SHA1 | Date | |
---|---|---|---|
|
17b45249f4 | ||
|
4f1c879528 | ||
|
2272c55092 | ||
|
b03c9b1d3c | ||
|
0c28eda706 | ||
|
7bc0bfa3ab | ||
|
827958a28a | ||
|
8200e1d3d8 | ||
|
729b20e9ab | ||
|
ca8ac7a43b | ||
|
b724a28a6f | ||
|
1e0a8392a2 | ||
|
2eb7b05653 | ||
|
d8a0b2dfcc | ||
|
2a7394b593 | ||
|
b1e64c8b8b | ||
|
70cf17607e | ||
|
b1568a550a | ||
|
e8e803e8ef | ||
|
2c453265fe | ||
|
7336d35fee | ||
|
bc18952c05 | ||
|
8e8d0b9c2c | ||
|
5a7da2b40b | ||
|
b6d38f442b | ||
|
8e8851d962 | ||
|
27b43b06a7 | ||
|
ff69eb1256 | ||
|
4ca581909a | ||
|
2722d976f5 | ||
|
946cda00d2 | ||
|
8227525c82 | ||
|
e61ae73749 | ||
|
040d1ee9e8 | ||
|
7f0da894fa | ||
|
62726df278 | ||
|
0ba09db6fe | ||
|
87334c11e6 | ||
|
40ef3aeda2 | ||
|
94fe4435a8 | ||
|
c204bc8e1f | ||
|
00615ae837 | ||
|
9f5f0d12dd | ||
|
8a291f7bfb | ||
|
21e3e3b82d | ||
|
a6bd6e130a | ||
|
fcdfa52892 | ||
|
73e6fe384e | ||
|
aff7a385a3 | ||
|
1e23ba05fa | ||
|
ee30d4da5b | ||
|
14508f0600 | ||
|
e3f8828da4 | ||
|
30adbf705c | ||
|
ee42fd68b1 | ||
|
736d9a6349 | ||
|
0055e15bc1 | ||
|
b2e1df7308 | ||
|
b935e9caf3 | ||
|
503ef0e05f | ||
|
dc6248413c | ||
|
e73b70ceb7 | ||
|
639198e774 | ||
|
768d969f89 | ||
|
aec422c277 | ||
|
6c14170de6 | ||
|
36a330aa66 | ||
|
acd4ac6a86 | ||
|
abe64cfe8f | ||
|
caae95d01d | ||
|
088429a16a | ||
|
b6145223c8 | ||
|
09256956f3 | ||
|
0ca90fdcee | ||
|
be21412f8a | ||
|
ae6bc47f87 | ||
|
7bb93f027a | ||
|
bc20acd63b | ||
|
a864cf5d4d | ||
|
2175f524e8 | ||
|
c1c21e92df | ||
|
2b676ee06e | ||
|
dda2f1d779 | ||
|
3a8388e24f | ||
|
a9eef8ae6d | ||
|
ac3091d8da | ||
|
ce3991de94 | ||
|
47b4b9bcb1 | ||
|
19475d7e2b | ||
|
269eb8ba25 | ||
|
b06ded683d | ||
|
466e9bde6b | ||
|
fc7d64aaea | ||
|
9a7d3192c0 | ||
|
e686adbea1 | ||
|
b7d53aa5f1 | ||
|
9b663477fd | ||
|
a70166280a | ||
|
a083c4b276 | ||
|
40751aa50a | ||
|
7c30ffe2c7 | ||
|
e1d56595c1 | ||
|
4540c8848e | ||
|
c88f288d12 | ||
|
e2918dffed | ||
|
7d9dfc27ee |
@@ -1,20 +1,11 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.3.0-dev0
|
current_version = 0.5.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize =
|
serialize =
|
||||||
{major}.{minor}.{patch}-{release}{build}
|
|
||||||
{major}.{minor}.{patch}
|
{major}.{minor}.{patch}
|
||||||
|
|
||||||
[bumpversion:part:release]
|
|
||||||
optional_value = prod
|
|
||||||
first_value = dev
|
|
||||||
values =
|
|
||||||
dev
|
|
||||||
rc
|
|
||||||
prod
|
|
||||||
|
|
||||||
[bumpversion:file:setup.py]
|
[bumpversion:file:setup.py]
|
||||||
|
|
||||||
[bumpversion:file:./prototorch/__init__.py]
|
[bumpversion:file:./prototorch/__init__.py]
|
||||||
|
31
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
31
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
---
|
||||||
|
name: Bug report
|
||||||
|
about: Create a report to help us improve
|
||||||
|
title: ''
|
||||||
|
labels: ''
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Describe the bug**
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
**To Reproduce**
|
||||||
|
Steps to reproduce the behavior:
|
||||||
|
1. Install Prototorch by running '...'
|
||||||
|
2. Run script '...'
|
||||||
|
3. See errors
|
||||||
|
|
||||||
|
**Expected behavior**
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Screenshots**
|
||||||
|
If applicable, add screenshots to help explain your problem.
|
||||||
|
|
||||||
|
**Desktop (please complete the following information):**
|
||||||
|
- OS: [e.g. Ubuntu 20.10]
|
||||||
|
- Prototorch Version: [e.g. v0.4.0]
|
||||||
|
- Python Version: [e.g. 3.9.5]
|
||||||
|
|
||||||
|
**Additional context**
|
||||||
|
Add any other context about the problem here.
|
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
---
|
||||||
|
name: Feature request
|
||||||
|
about: Suggest an idea for this project
|
||||||
|
title: ''
|
||||||
|
labels: ''
|
||||||
|
assignees: ''
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Is your feature request related to a problem? Please describe.**
|
||||||
|
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||||
|
|
||||||
|
**Describe the solution you'd like**
|
||||||
|
A clear and concise description of what you want to happen.
|
||||||
|
|
||||||
|
**Describe alternatives you've considered**
|
||||||
|
A clear and concise description of any alternative solutions or features you've considered.
|
||||||
|
|
||||||
|
**Additional context**
|
||||||
|
Add any other context or screenshots about the feature request here.
|
5
.github/workflows/pythonapp.yml
vendored
5
.github/workflows/pythonapp.yml
vendored
@@ -23,10 +23,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .
|
pip install .[all]
|
||||||
- name: Install extras
|
|
||||||
run: |
|
|
||||||
pip install -r requirements.txt
|
|
||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
pip install flake8
|
pip install flake8
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -155,3 +155,4 @@ scratch*
|
|||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
reports
|
reports
|
||||||
|
artifacts
|
10
.travis.yml
10
.travis.yml
@@ -4,11 +4,11 @@ language: python
|
|||||||
python: 3.8
|
python: 3.8
|
||||||
cache:
|
cache:
|
||||||
directories:
|
directories:
|
||||||
|
- "$HOME/.cache/pip"
|
||||||
- "./tests/artifacts"
|
- "./tests/artifacts"
|
||||||
# - "$HOME/.prototorch/datasets"
|
- "$HOME/datasets"
|
||||||
install:
|
install:
|
||||||
- pip install . --progress-bar off
|
- pip install .[all] --progress-bar off
|
||||||
- pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Generate code coverage report
|
# Generate code coverage report
|
||||||
script:
|
script:
|
||||||
@@ -25,8 +25,8 @@ deploy:
|
|||||||
password:
|
password:
|
||||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||||
on:
|
on:
|
||||||
tags: true
|
tags: true
|
||||||
skip_existing: true
|
skip_existing: true
|
||||||
|
|
||||||
# The password is encrypted with:
|
# The password is encrypted with:
|
||||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||||
|
@@ -31,15 +31,15 @@ To also install the extras, use
|
|||||||
pip install -U prototorch[all]
|
pip install -U prototorch[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
*Note: If you're using [ZSH](https://www.zsh.org/), the square brackets `[ ]`
|
*Note: If you're using [ZSH](https://www.zsh.org/) (which is also the default
|
||||||
have to be escaped like so: `\[\]`, making the install command `pip install -U
|
shell on MacOS now), the square brackets `[ ]` have to be escaped like so:
|
||||||
prototorch\[all\]`.*
|
`\[\]`, making the install command `pip install -U prototorch\[all\]`.*
|
||||||
|
|
||||||
To install the bleeding-edge features and improvements:
|
To install the bleeding-edge features and improvements:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/si-cim/prototorch.git
|
git clone https://github.com/si-cim/prototorch.git
|
||||||
git checkout dev
|
|
||||||
cd prototorch
|
cd prototorch
|
||||||
|
git checkout dev
|
||||||
pip install -e .[all]
|
pip install -e .[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@@ -1,13 +1,16 @@
|
|||||||
# ProtoTorch Releases
|
# ProtoTorch Releases
|
||||||
|
|
||||||
|
## Release 0.5.0
|
||||||
|
|
||||||
|
- Breaking: Removed deprecated `prototorch.modules.Prototypes1D`.
|
||||||
|
- Use `prototorch.components.LabeledComponents` instead.
|
||||||
|
|
||||||
## Release 0.2.0
|
## Release 0.2.0
|
||||||
|
|
||||||
### Includes
|
|
||||||
- Fixes in example scripts.
|
- Fixes in example scripts.
|
||||||
|
|
||||||
## Release 0.1.1-dev0
|
## Release 0.1.1-dev0
|
||||||
|
|
||||||
### Includes
|
|
||||||
- Minor bugfixes.
|
- Minor bugfixes.
|
||||||
- 100% line coverage.
|
- 100% line coverage.
|
||||||
|
|
||||||
|
@@ -1,13 +1,24 @@
|
|||||||
.. ProtoFlow API Reference
|
.. ProtoTorch API Reference
|
||||||
|
|
||||||
ProtoFlow API Reference
|
ProtoTorch API Reference
|
||||||
======================================
|
======================================
|
||||||
|
|
||||||
Datasets
|
Datasets
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
|
||||||
|
Common Datasets
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
.. automodule:: prototorch.datasets
|
.. automodule:: prototorch.datasets
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
|
||||||
|
|
||||||
|
Abstract Datasets
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Abstract Datasets are used to build your own datasets.
|
||||||
|
|
||||||
|
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
|
||||||
|
:members:
|
||||||
|
|
||||||
Functions
|
Functions
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
@@ -12,9 +12,8 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, os.path.abspath("../../"))
|
|
||||||
|
|
||||||
import sphinx_rtd_theme
|
sys.path.insert(0, os.path.abspath("../../"))
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
@@ -24,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.3.0-dev0"
|
release = "0.5.0"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
@@ -47,6 +46,7 @@ extensions = [
|
|||||||
"sphinx.ext.viewcode",
|
"sphinx.ext.viewcode",
|
||||||
"sphinx_rtd_theme",
|
"sphinx_rtd_theme",
|
||||||
"sphinxcontrib.katex",
|
"sphinxcontrib.katex",
|
||||||
|
'sphinx_autodoc_typehints',
|
||||||
]
|
]
|
||||||
|
|
||||||
# katex_prerender = True
|
# katex_prerender = True
|
||||||
@@ -128,15 +128,12 @@ latex_elements = {
|
|||||||
# The paper size ("letterpaper" or "a4paper").
|
# The paper size ("letterpaper" or "a4paper").
|
||||||
#
|
#
|
||||||
# "papersize": "letterpaper",
|
# "papersize": "letterpaper",
|
||||||
|
|
||||||
# The font size ("10pt", "11pt" or "12pt").
|
# The font size ("10pt", "11pt" or "12pt").
|
||||||
#
|
#
|
||||||
# "pointsize": "10pt",
|
# "pointsize": "10pt",
|
||||||
|
|
||||||
# Additional stuff for the LaTeX preamble.
|
# Additional stuff for the LaTeX preamble.
|
||||||
#
|
#
|
||||||
# "preamble": "",
|
# "preamble": "",
|
||||||
|
|
||||||
# Latex figure (float) alignment
|
# Latex figure (float) alignment
|
||||||
#
|
#
|
||||||
# "figure_align": "htbp",
|
# "figure_align": "htbp",
|
||||||
@@ -146,15 +143,21 @@ latex_elements = {
|
|||||||
# (source start file, target name, title,
|
# (source start file, target name, title,
|
||||||
# author, documentclass [howto, manual, or own class]).
|
# author, documentclass [howto, manual, or own class]).
|
||||||
latex_documents = [
|
latex_documents = [
|
||||||
(master_doc, "prototorch.tex", "ProtoTorch Documentation",
|
(
|
||||||
"Jensun Ravichandran", "manual"),
|
master_doc,
|
||||||
|
"prototorch.tex",
|
||||||
|
"ProtoTorch Documentation",
|
||||||
|
"Jensun Ravichandran",
|
||||||
|
"manual",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# -- Options for manual page output ---------------------------------------
|
# -- Options for manual page output ---------------------------------------
|
||||||
|
|
||||||
# One entry per manual page. List of tuples
|
# One entry per manual page. List of tuples
|
||||||
# (source start file, name, description, authors, manual section).
|
# (source start file, name, description, authors, manual section).
|
||||||
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
|
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
|
||||||
|
1)]
|
||||||
|
|
||||||
# -- Options for Texinfo output -------------------------------------------
|
# -- Options for Texinfo output -------------------------------------------
|
||||||
|
|
||||||
@@ -162,15 +165,24 @@ man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)
|
|||||||
# (source start file, target name, title, author,
|
# (source start file, target name, title, author,
|
||||||
# dir menu entry, description, category)
|
# dir menu entry, description, category)
|
||||||
texinfo_documents = [
|
texinfo_documents = [
|
||||||
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
|
(
|
||||||
"Prototype-based machine learning in PyTorch.",
|
master_doc,
|
||||||
"Miscellaneous"),
|
"prototorch",
|
||||||
|
"ProtoTorch Documentation",
|
||||||
|
author,
|
||||||
|
"prototorch",
|
||||||
|
"Prototype-based machine learning in PyTorch.",
|
||||||
|
"Miscellaneous",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Example configuration for intersphinx: refer to the Python standard library.
|
# Example configuration for intersphinx: refer to the Python standard library.
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
"python": ("https://docs.python.org/", None),
|
"python": ("https://docs.python.org/", None),
|
||||||
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
||||||
|
"torch": ('https://pytorch.org/docs/stable/', None),
|
||||||
|
"pytorch_lightning":
|
||||||
|
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
# -- Options for Epub output ----------------------------------------------
|
# -- Options for Epub output ----------------------------------------------
|
||||||
|
@@ -3,10 +3,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
@@ -24,18 +24,17 @@ class Model(torch.nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""GLVQ model for training on 2D Iris data."""
|
"""GLVQ model for training on 2D Iris data."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proto_layer = Prototypes1D(
|
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
||||||
input_dim=2,
|
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
|
||||||
prototypes_per_class=3,
|
self.proto_layer = LabeledComponents(
|
||||||
nclasses=3,
|
prototype_distribution,
|
||||||
prototype_initializer="stratified_random",
|
prototype_initializer,
|
||||||
data=[x_train, y_train])
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.proto_layer.prototypes
|
prototypes, prototype_labels = self.proto_layer()
|
||||||
plabels = self.proto_layer.prototype_labels
|
distances = euclidean_distance(x, prototypes)
|
||||||
dis = euclidean_distance(x, protos)
|
return distances, prototype_labels
|
||||||
return dis, plabels
|
|
||||||
|
|
||||||
|
|
||||||
# Build the GLVQ model
|
# Build the GLVQ model
|
||||||
@@ -52,47 +51,54 @@ x_in = torch.Tensor(x_train)
|
|||||||
y_in = torch.Tensor(y_train)
|
y_in = torch.Tensor(y_train)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
title = "Prototype Visualization"
|
TITLE = "Prototype Visualization"
|
||||||
fig = plt.figure(title)
|
fig = plt.figure(TITLE)
|
||||||
for epoch in range(70):
|
for epoch in range(70):
|
||||||
# Compute loss
|
# Compute loss
|
||||||
dis, plabels = model(x_in)
|
distances, prototype_labels = model(x_in)
|
||||||
loss = criterion([dis, plabels], y_in)
|
loss = criterion([distances, prototype_labels], y_in)
|
||||||
with torch.no_grad():
|
|
||||||
pred = wtac(dis, plabels)
|
|
||||||
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
|
||||||
acc = 100. * correct / len(x_train)
|
|
||||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
|
|
||||||
|
|
||||||
# Take a gradient descent step
|
# Compute Accuracy
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = wtac(distances, prototype_labels)
|
||||||
|
correct = predictions.eq(y_in.view_as(predictions)).sum().item()
|
||||||
|
acc = 100.0 * correct / len(x_train)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# Get the prototypes form the model
|
# Get the prototypes form the model
|
||||||
protos = model.proto_layer.prototypes.data.numpy()
|
prototypes = model.proto_layer.components.numpy()
|
||||||
if np.isnan(np.sum(protos)):
|
if np.isnan(np.sum(prototypes)):
|
||||||
print("Stopping training because of `nan` in prototypes.")
|
print("Stopping training because of `nan` in prototypes.")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
# Visualize the data and the prototypes
|
||||||
ax = fig.gca()
|
ax = fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(title)
|
ax.set_title(TITLE)
|
||||||
ax.set_xlabel("Data dimension 1")
|
ax.set_xlabel("Data dimension 1")
|
||||||
ax.set_ylabel("Data dimension 2")
|
ax.set_ylabel("Data dimension 2")
|
||||||
cmap = "viridis"
|
cmap = "viridis"
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(
|
||||||
protos[:, 1],
|
prototypes[:, 0],
|
||||||
c=plabels,
|
prototypes[:, 1],
|
||||||
cmap=cmap,
|
c=prototype_labels,
|
||||||
edgecolor="k",
|
cmap=cmap,
|
||||||
marker="D",
|
edgecolor="k",
|
||||||
s=50)
|
marker="D",
|
||||||
|
s=50,
|
||||||
|
)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, prototypes))
|
||||||
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
||||||
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
||||||
@@ -102,7 +108,7 @@ for epoch in range(70):
|
|||||||
torch_input = torch.Tensor(mesh_input)
|
torch_input = torch.Tensor(mesh_input)
|
||||||
d = model(torch_input)[0]
|
d = model(torch_input)[0]
|
||||||
w_indices = torch.argmin(d, dim=1)
|
w_indices = torch.argmin(d, dim=1)
|
||||||
y_pred = torch.index_select(plabels, 0, w_indices)
|
y_pred = torch.index_select(prototype_labels, 0, w_indices)
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = y_pred.reshape(xx.shape)
|
||||||
|
|
||||||
# Plot voronoi regions
|
# Plot voronoi regions
|
||||||
|
@@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
|
|
||||||
from prototorch.datasets.tecator import Tecator
|
from prototorch.datasets.tecator import Tecator
|
||||||
from prototorch.functions.distances import sed
|
from prototorch.functions.distances import sed
|
||||||
from prototorch.modules import Prototypes1D
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.utils.colors import get_legend_handles
|
from prototorch.utils.colors import get_legend_handles
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
# Prepare the dataset and dataloader
|
# Prepare the dataset and dataloader
|
||||||
train_data = Tecator(root="./artifacts", train=True)
|
train_data = Tecator(root="./artifacts", train=True)
|
||||||
@@ -19,20 +18,22 @@ class Model(torch.nn.Module):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""GMLVQ model as a siamese network."""
|
"""GMLVQ model as a siamese network."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
x, y = train_data.data, train_data.targets
|
prototype_initializer = StratifiedMeanInitializer(train_loader)
|
||||||
self.p1 = Prototypes1D(input_dim=100,
|
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
|
||||||
prototypes_per_class=2,
|
|
||||||
nclasses=2,
|
self.proto_layer = LabeledComponents(
|
||||||
prototype_initializer="stratified_random",
|
prototype_distribution,
|
||||||
data=[x, y])
|
prototype_initializer,
|
||||||
|
)
|
||||||
|
|
||||||
self.omega = torch.nn.Linear(in_features=100,
|
self.omega = torch.nn.Linear(in_features=100,
|
||||||
out_features=100,
|
out_features=100,
|
||||||
bias=False)
|
bias=False)
|
||||||
torch.nn.init.eye_(self.omega.weight)
|
torch.nn.init.eye_(self.omega.weight)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.p1.prototypes
|
protos = self.proto_layer.components
|
||||||
plabels = self.p1.prototype_labels
|
plabels = self.proto_layer.component_labels
|
||||||
|
|
||||||
# Process `x` and `protos` through `omega`
|
# Process `x` and `protos` through `omega`
|
||||||
x_map = self.omega(x)
|
x_map = self.omega(x)
|
||||||
@@ -84,8 +85,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# Get the prototypes form the model
|
# Get the prototypes form the model
|
||||||
protos = model.p1.prototypes.data.numpy()
|
protos = model.proto_layer.components.numpy()
|
||||||
plabels = model.p1.prototype_labels
|
plabels = model.proto_layer.component_labels.numpy()
|
||||||
|
|
||||||
# Visualize the prototypes
|
# Visualize the prototypes
|
||||||
title = "Tecator Prototypes"
|
title = "Tecator Prototypes"
|
||||||
|
@@ -12,46 +12,54 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import transforms
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||||
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.models import GTLVQ
|
from prototorch.modules.models import GTLVQ
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
# Parameters and options
|
# Parameters and options
|
||||||
n_epochs = 50
|
num_epochs = 50
|
||||||
batch_size_train = 64
|
batch_size_train = 64
|
||||||
batch_size_test = 1000
|
batch_size_test = 1000
|
||||||
learning_rate = 0.1
|
learning_rate = 0.1
|
||||||
momentum = 0.5
|
momentum = 0.5
|
||||||
log_interval = 10
|
log_interval = 10
|
||||||
cuda = "cuda:1"
|
cuda = "cuda:0"
|
||||||
random_seed = 1
|
random_seed = 1
|
||||||
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
|
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Configures reproducability
|
# Configures reproducability
|
||||||
torch.manual_seed(random_seed)
|
torch.manual_seed(random_seed)
|
||||||
np.random.seed(random_seed)
|
np.random.seed(random_seed)
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
'./files/',
|
torchvision.datasets.MNIST(
|
||||||
train=True,
|
"./files/",
|
||||||
download=True,
|
train=True,
|
||||||
transform=torchvision.transforms.Compose(
|
download=True,
|
||||||
[transforms.ToTensor(),
|
transform=torchvision.transforms.Compose([
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
transforms.ToTensor(),
|
||||||
batch_size=batch_size_train,
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
shuffle=True)
|
]),
|
||||||
|
),
|
||||||
|
batch_size=batch_size_train,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
test_loader = torch.utils.data.DataLoader(
|
||||||
'./files/',
|
torchvision.datasets.MNIST(
|
||||||
train=False,
|
"./files/",
|
||||||
download=True,
|
train=False,
|
||||||
transform=torchvision.transforms.Compose(
|
download=True,
|
||||||
[transforms.ToTensor(),
|
transform=torchvision.transforms.Compose([
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
transforms.ToTensor(),
|
||||||
batch_size=batch_size_test,
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
shuffle=True)
|
]),
|
||||||
|
),
|
||||||
|
batch_size=batch_size_test,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Define the GLVQ model plus appropriate feature extractor
|
# Define the GLVQ model plus appropriate feature extractor
|
||||||
@@ -67,25 +75,34 @@ class CNNGTLVQ(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
super(CNNGTLVQ, self).__init__()
|
super(CNNGTLVQ, self).__init__()
|
||||||
|
|
||||||
#Feature Extractor - Simple CNN
|
# Feature Extractor - Simple CNN
|
||||||
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
|
self.fe = nn.Sequential(
|
||||||
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
|
nn.Conv2d(1, 32, 3, 1),
|
||||||
nn.MaxPool2d(2), nn.Dropout(0.25),
|
nn.ReLU(),
|
||||||
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
|
nn.Conv2d(32, 64, 3, 1),
|
||||||
nn.Dropout(0.5), nn.LeakyReLU(),
|
nn.ReLU(),
|
||||||
nn.LayerNorm(bottleneck_dim))
|
nn.MaxPool2d(2),
|
||||||
|
nn.Dropout(0.25),
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(9216, bottleneck_dim),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.LayerNorm(bottleneck_dim),
|
||||||
|
)
|
||||||
|
|
||||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
# Forward pass of subspace and prototype initialization data through feature extractor
|
||||||
subspace_data = self.fe(subspace_data)
|
subspace_data = self.fe(subspace_data)
|
||||||
prototype_data[0] = self.fe(prototype_data[0])
|
prototype_data[0] = self.fe(prototype_data[0])
|
||||||
|
|
||||||
# Initialization of GTLVQ
|
# Initialization of GTLVQ
|
||||||
self.gtlvq = GTLVQ(num_classes,
|
self.gtlvq = GTLVQ(
|
||||||
subspace_data,
|
num_classes,
|
||||||
prototype_data,
|
subspace_data,
|
||||||
tangent_projection_type=tangent_projection_type,
|
prototype_data,
|
||||||
feature_dim=bottleneck_dim,
|
tangent_projection_type=tangent_projection_type,
|
||||||
prototypes_per_class=prototypes_per_class)
|
feature_dim=bottleneck_dim,
|
||||||
|
prototypes_per_class=prototypes_per_class,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Feature Extraction
|
# Feature Extraction
|
||||||
@@ -103,30 +120,34 @@ subspace_data = torch.cat(
|
|||||||
prototype_data = next(iter(train_loader))
|
prototype_data = next(iter(train_loader))
|
||||||
|
|
||||||
# Build the CNN GTLVQ model
|
# Build the CNN GTLVQ model
|
||||||
model = CNNGTLVQ(10,
|
model = CNNGTLVQ(
|
||||||
subspace_data,
|
10,
|
||||||
prototype_data,
|
subspace_data,
|
||||||
tangent_projection_type="local",
|
prototype_data,
|
||||||
bottleneck_dim=128).to(device)
|
tangent_projection_type="local",
|
||||||
|
bottleneck_dim=128,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# Optimize using SGD optimizer from `torch.optim`
|
# Optimize using SGD optimizer from `torch.optim`
|
||||||
optimizer = torch.optim.Adam([{
|
optimizer = torch.optim.Adam(
|
||||||
'params': model.fe.parameters()
|
[{
|
||||||
}, {
|
"params": model.fe.parameters()
|
||||||
'params': model.gtlvq.parameters()
|
}, {
|
||||||
}],
|
"params": model.gtlvq.parameters()
|
||||||
lr=learning_rate)
|
}],
|
||||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
lr=learning_rate,
|
||||||
|
)
|
||||||
|
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
for epoch in range(n_epochs):
|
for epoch in range(num_epochs):
|
||||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
||||||
model.train()
|
model.train()
|
||||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
x_train, y_train = x_train.to(device), y_train.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
distances = model(x_train)
|
distances = model(x_train)
|
||||||
plabels = model.gtlvq.cls.prototype_labels.to(device)
|
plabels = model.gtlvq.cls.component_labels.to(device)
|
||||||
|
|
||||||
# Compute loss.
|
# Compute loss.
|
||||||
loss = criterion([distances, plabels], y_train)
|
loss = criterion([distances, plabels], y_train)
|
||||||
@@ -139,8 +160,8 @@ for epoch in range(n_epochs):
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||||
print(
|
print(
|
||||||
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||||
Train Acc: {acc.item():02.02f}')
|
Train Acc: {acc.item():02.02f}")
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -154,9 +175,9 @@ for epoch in range(n_epochs):
|
|||||||
i = torch.argmin(test_distances, 1)
|
i = torch.argmin(test_distances, 1)
|
||||||
correct += torch.sum(y_test == test_plabels[i])
|
correct += torch.sum(y_test == test_plabels[i])
|
||||||
total += y_test.size(0)
|
total += y_test.size(0)
|
||||||
print('Accuracy of the network on the test images: %d %%' %
|
print("Accuracy of the network on the test images: %d %%" %
|
||||||
(torch.true_divide(correct, total) * 100))
|
(torch.true_divide(correct, total) * 100))
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
PATH = './glvq_mnist_model.pth'
|
PATH = "./glvq_mnist_model.pth"
|
||||||
torch.save(model.state_dict(), PATH)
|
torch.save(model.state_dict(), PATH)
|
||||||
|
@@ -3,14 +3,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from sklearn.datasets import load_iris
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from sklearn.metrics import accuracy_score
|
|
||||||
|
|
||||||
from prototorch.functions.competitions import stratified_min
|
from prototorch.functions.competitions import stratified_min
|
||||||
from prototorch.functions.distances import lomega_distance
|
from prototorch.functions.distances import lomega_distance
|
||||||
from prototorch.functions.init import eye_
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.metrics import accuracy_score
|
||||||
|
|
||||||
# Prepare training data
|
# Prepare training data
|
||||||
x_train, y_train = load_iris(True)
|
x_train, y_train = load_iris(True)
|
||||||
@@ -22,17 +20,19 @@ class Model(torch.nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Local-GMLVQ model."""
|
"""Local-GMLVQ model."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.p1 = Prototypes1D(input_dim=2,
|
|
||||||
prototype_distribution=[1, 2, 2],
|
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
||||||
prototype_initializer="stratified_random",
|
prototype_distribution = [1, 2, 2]
|
||||||
data=[x_train, y_train])
|
self.proto_layer = LabeledComponents(
|
||||||
omegas = torch.zeros(5, 2, 2)
|
prototype_distribution,
|
||||||
|
prototype_initializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
omegas = torch.eye(2, 2).repeat(5, 1, 1)
|
||||||
self.omegas = torch.nn.Parameter(omegas)
|
self.omegas = torch.nn.Parameter(omegas)
|
||||||
eye_(self.omegas)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.p1.prototypes
|
protos, plabels = self.proto_layer()
|
||||||
plabels = self.p1.prototype_labels
|
|
||||||
omegas = self.omegas
|
omegas = self.omegas
|
||||||
dis = lomega_distance(x, protos, omegas)
|
dis = lomega_distance(x, protos, omegas)
|
||||||
return dis, plabels
|
return dis, plabels
|
||||||
@@ -67,7 +67,7 @@ for epoch in range(100):
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# Get the prototypes form the model
|
# Get the prototypes form the model
|
||||||
protos = model.p1.prototypes.data.numpy()
|
protos = model.proto_layer.components.numpy()
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
# Visualize the data and the prototypes
|
||||||
ax = fig.gca()
|
ax = fig.gca()
|
||||||
@@ -76,14 +76,16 @@ for epoch in range(100):
|
|||||||
ax.set_xlabel("Data dimension 1")
|
ax.set_xlabel("Data dimension 1")
|
||||||
ax.set_ylabel("Data dimension 2")
|
ax.set_ylabel("Data dimension 2")
|
||||||
cmap = "viridis"
|
cmap = "viridis"
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(
|
||||||
protos[:, 1],
|
protos[:, 0],
|
||||||
c=plabels,
|
protos[:, 1],
|
||||||
cmap=cmap,
|
c=plabels,
|
||||||
edgecolor='k',
|
cmap=cmap,
|
||||||
marker='D',
|
edgecolor="k",
|
||||||
s=50)
|
marker="D",
|
||||||
|
s=50,
|
||||||
|
)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
|
65
examples/new_components.py
Normal file
65
examples/new_components.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""This example script shows the usage of the new components architecture.
|
||||||
|
|
||||||
|
Serialization/deserialization also works as expected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DATASET
|
||||||
|
import torch
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
|
x_train = x_train[:, [0, 2]]
|
||||||
|
scaler.fit(x_train)
|
||||||
|
x_train = scaler.transform(x_train)
|
||||||
|
|
||||||
|
x_train = torch.Tensor(x_train)
|
||||||
|
y_train = torch.Tensor(y_train)
|
||||||
|
num_classes = len(torch.unique(y_train))
|
||||||
|
|
||||||
|
# CREATE NEW COMPONENTS
|
||||||
|
from prototorch.components import *
|
||||||
|
from prototorch.components.initializers import *
|
||||||
|
|
||||||
|
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||||
|
print(unsupervised())
|
||||||
|
|
||||||
|
prototypes = LabeledComponents(
|
||||||
|
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(prototypes())
|
||||||
|
|
||||||
|
components = ReasoningComponents(
|
||||||
|
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(components())
|
||||||
|
|
||||||
|
# TEST SERIALIZATION
|
||||||
|
import io
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(unsupervised, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_unsupervised = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(prototypes, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_prototypes = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||||
|
component_labels), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(components, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_components = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(components.components == serialized_components.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||||
|
), "Serialization of Components failed."
|
@@ -1,31 +1,32 @@
|
|||||||
"""ProtoTorch package."""
|
"""ProtoTorch package."""
|
||||||
|
|
||||||
# #############################################
|
import pkgutil
|
||||||
# Core Setup
|
|
||||||
# #############################################
|
|
||||||
__version__ = "0.3.0-dev0"
|
|
||||||
|
|
||||||
from prototorch import datasets, functions, modules
|
import pkg_resources
|
||||||
|
|
||||||
|
from . import components, datasets, functions, modules, utils
|
||||||
|
from .datasets import *
|
||||||
|
|
||||||
|
# Core Setup
|
||||||
|
__version__ = "0.5.0"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"datasets",
|
"datasets",
|
||||||
"functions",
|
"functions",
|
||||||
"modules",
|
"modules",
|
||||||
|
"components",
|
||||||
|
"utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
# #############################################
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
# #############################################
|
|
||||||
import pkgutil
|
|
||||||
import pkg_resources
|
|
||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__)
|
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|
||||||
|
|
||||||
def discover_plugins():
|
def discover_plugins():
|
||||||
return {
|
return {
|
||||||
entry_point.name: entry_point.load()
|
entry_point.name: entry_point.load()
|
||||||
for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
|
for entry_point in pkg_resources.iter_entry_points(
|
||||||
|
"prototorch.plugins")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -33,12 +34,10 @@ discovered_plugins = discover_plugins()
|
|||||||
locals().update(discovered_plugins)
|
locals().update(discovered_plugins)
|
||||||
|
|
||||||
# Generate combines __version__ and __all__
|
# Generate combines __version__ and __all__
|
||||||
version_plugins = "\n".join(
|
version_plugins = "\n".join([
|
||||||
[
|
"- " + name + ": v" + plugin.__version__
|
||||||
"- " + name + ": v" + plugin.__version__
|
for name, plugin in discovered_plugins.items()
|
||||||
for name, plugin in discovered_plugins.items()
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
if version_plugins != "":
|
if version_plugins != "":
|
||||||
version_plugins = "\nPlugins: \n" + version_plugins
|
version_plugins = "\nPlugins: \n" + version_plugins
|
||||||
|
|
||||||
|
2
prototorch/components/__init__.py
Normal file
2
prototorch/components/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from prototorch.components.components import *
|
||||||
|
from prototorch.components.initializers import *
|
229
prototorch/components/components.py
Normal file
229
prototorch/components/components.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""ProtoTorch components modules."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||||
|
ComponentsInitializer,
|
||||||
|
CustomLabelsInitializer,
|
||||||
|
EqualLabelsInitializer,
|
||||||
|
UnequalLabelsInitializer,
|
||||||
|
ZeroReasoningsInitializer)
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from .initializers import parse_data_arg
|
||||||
|
|
||||||
|
|
||||||
|
def get_labels_object(distribution):
|
||||||
|
if isinstance(distribution, dict):
|
||||||
|
if "num_classes" in distribution.keys():
|
||||||
|
labels = EqualLabelsInitializer(
|
||||||
|
distribution["num_classes"],
|
||||||
|
distribution["prototypes_per_class"])
|
||||||
|
else:
|
||||||
|
labels = CustomLabelsInitializer(distribution)
|
||||||
|
elif isinstance(distribution, tuple):
|
||||||
|
num_classes, prototypes_per_class = distribution
|
||||||
|
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||||
|
elif isinstance(distribution, list):
|
||||||
|
labels = UnequalLabelsInitializer(distribution)
|
||||||
|
else:
|
||||||
|
msg = f"`distribution` not understood." \
|
||||||
|
f"You have provided: {distribution=}."
|
||||||
|
raise ValueError(msg)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def _precheck_initializer(initializer):
|
||||||
|
if not isinstance(initializer, ComponentsInitializer):
|
||||||
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
|
f"{ComponentsInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
|
||||||
|
|
||||||
|
class Components(torch.nn.Module):
|
||||||
|
"""Components is a set of learnable Tensors."""
|
||||||
|
def __init__(self,
|
||||||
|
num_components=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Ignore all initialization settings if initialized_components is given.
|
||||||
|
if initialized_components is not None:
|
||||||
|
self._register_components(initialized_components)
|
||||||
|
if num_components is not None or initializer is not None:
|
||||||
|
wmsg = "Arguments ignored while initializing Components"
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
else:
|
||||||
|
self._initialize_components(num_components, initializer)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_components(self):
|
||||||
|
return len(self._components)
|
||||||
|
|
||||||
|
def _register_components(self, components):
|
||||||
|
self.register_parameter("_components", Parameter(components))
|
||||||
|
|
||||||
|
def _initialize_components(self, num_components, initializer):
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
_components = initializer.generate(num_components)
|
||||||
|
self._register_components(_components)
|
||||||
|
|
||||||
|
def add_components(self,
|
||||||
|
num=1,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
_components = torch.cat([self._components, initialized_components])
|
||||||
|
else:
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
_new = initializer.generate(num)
|
||||||
|
_components = torch.cat([self._components, _new])
|
||||||
|
self._register_components(_components)
|
||||||
|
|
||||||
|
def remove_components(self, indices=None):
|
||||||
|
mask = torch.ones(self.num_components, dtype=torch.bool)
|
||||||
|
mask[indices] = False
|
||||||
|
_components = self._components[mask]
|
||||||
|
self._register_components(_components)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""Tensor containing the component tensors."""
|
||||||
|
return self._components.detach()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._components
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"(components): (shape: {tuple(self._components.shape)})"
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledComponents(Components):
|
||||||
|
"""LabeledComponents generate a set of components and a set of labels.
|
||||||
|
|
||||||
|
Every Component has a label assigned.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
distribution=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
components, component_labels = parse_data_arg(
|
||||||
|
initialized_components)
|
||||||
|
super().__init__(initialized_components=components)
|
||||||
|
self._labels = component_labels
|
||||||
|
else:
|
||||||
|
labels = get_labels_object(distribution)
|
||||||
|
self.initial_distribution = labels.distribution
|
||||||
|
_labels = labels.generate()
|
||||||
|
super().__init__(len(_labels), initializer=initializer)
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
def _register_labels(self, labels):
|
||||||
|
self.register_buffer("_labels", labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
clabels, counts = torch.unique(self._labels,
|
||||||
|
sorted=True,
|
||||||
|
return_counts=True)
|
||||||
|
return dict(zip(clabels.tolist(), counts.tolist()))
|
||||||
|
|
||||||
|
def _initialize_components(self, num_components, initializer):
|
||||||
|
if isinstance(initializer, ClassAwareInitializer):
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
_components = initializer.generate(num_components,
|
||||||
|
self.initial_distribution)
|
||||||
|
self._register_components(_components)
|
||||||
|
else:
|
||||||
|
super()._initialize_components(num_components, initializer)
|
||||||
|
|
||||||
|
def add_components(self, distribution, initializer):
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
labels = get_labels_object(distribution)
|
||||||
|
new_labels = labels.generate()
|
||||||
|
_labels = torch.cat([self._labels, new_labels])
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
# Components
|
||||||
|
if isinstance(initializer, ClassAwareInitializer):
|
||||||
|
_new = initializer.generate(len(new_labels), labels.distribution)
|
||||||
|
else:
|
||||||
|
_new = initializer.generate(len(new_labels))
|
||||||
|
_components = torch.cat([self._components, _new])
|
||||||
|
self._register_components(_components)
|
||||||
|
|
||||||
|
def remove_components(self, indices=None):
|
||||||
|
# Components
|
||||||
|
mask = super().remove_components(indices)
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
_labels = self._labels[mask]
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component_labels(self):
|
||||||
|
"""Tensor containing the component tensors."""
|
||||||
|
return self._labels.detach()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._labels
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningComponents(Components):
|
||||||
|
"""ReasoningComponents generate a set of components and a set of reasoning matrices.
|
||||||
|
|
||||||
|
Every Component has a reasoning matrix assigned.
|
||||||
|
|
||||||
|
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
|
||||||
|
first element is called positive reasoning :math:`p`, the second negative
|
||||||
|
reasoning :math:`n`. A components can reason in favour (positive) of a
|
||||||
|
class, against (negative) a class or not at all (neutral).
|
||||||
|
|
||||||
|
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
|
||||||
|
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
|
||||||
|
three element probability distribution.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reasonings=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
components, reasonings = initialized_components
|
||||||
|
|
||||||
|
super().__init__(initialized_components=components)
|
||||||
|
self.register_parameter("_reasonings", reasonings)
|
||||||
|
else:
|
||||||
|
self._initialize_reasonings(reasonings)
|
||||||
|
super().__init__(len(self._reasonings), initializer=initializer)
|
||||||
|
|
||||||
|
def _initialize_reasonings(self, reasonings):
|
||||||
|
if isinstance(reasonings, tuple):
|
||||||
|
num_classes, num_components = reasonings
|
||||||
|
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
||||||
|
|
||||||
|
_reasonings = reasonings.generate()
|
||||||
|
self.register_parameter("_reasonings", _reasonings)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasonings(self):
|
||||||
|
"""Returns Reasoning Matrix.
|
||||||
|
|
||||||
|
Dimension NxCx2
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._reasonings.detach()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._reasonings
|
234
prototorch/components/initializers.py
Normal file
234
prototorch/components/initializers.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""ProtoTroch Component and Label Initializers."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def parse_data_arg(data_arg):
|
||||||
|
if isinstance(data_arg, Dataset):
|
||||||
|
data_arg = DataLoader(data_arg, batch_size=len(data_arg))
|
||||||
|
|
||||||
|
if isinstance(data_arg, DataLoader):
|
||||||
|
data = torch.tensor([])
|
||||||
|
targets = torch.tensor([])
|
||||||
|
for x, y in data_arg:
|
||||||
|
data = torch.cat([data, x])
|
||||||
|
targets = torch.cat([targets, y])
|
||||||
|
else:
|
||||||
|
data, targets = data_arg
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
data = torch.Tensor(data)
|
||||||
|
if not isinstance(targets, torch.Tensor):
|
||||||
|
wmsg = f"Converting targets to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
targets = torch.Tensor(targets)
|
||||||
|
return data, targets
|
||||||
|
|
||||||
|
|
||||||
|
def get_subinitializers(data, targets, clabels, subinit_type):
|
||||||
|
initializers = dict()
|
||||||
|
for clabel in clabels:
|
||||||
|
class_data = data[targets == clabel]
|
||||||
|
class_initializer = subinit_type(class_data)
|
||||||
|
initializers[clabel] = (class_initializer)
|
||||||
|
return initializers
|
||||||
|
|
||||||
|
|
||||||
|
# Components
|
||||||
|
class ComponentsInitializer(object):
|
||||||
|
def generate(self, number_of_components):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class DimensionAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, dims):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(dims, Iterable):
|
||||||
|
self.components_dims = tuple(dims)
|
||||||
|
else:
|
||||||
|
self.components_dims = (dims, )
|
||||||
|
|
||||||
|
|
||||||
|
class OnesInitializer(DimensionAwareInitializer):
|
||||||
|
def __init__(self, dims, scale=1.0):
|
||||||
|
super().__init__(dims)
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.ones(gen_dims) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosInitializer(DimensionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.zeros(gen_dims)
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInitializer(DimensionAwareInitializer):
|
||||||
|
def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
|
||||||
|
super().__init__(dims)
|
||||||
|
self.minimum = minimum
|
||||||
|
self.maximum = maximum
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.ones(gen_dims).uniform_(self.minimum,
|
||||||
|
self.maximum) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class DataAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
|
super().__init__()
|
||||||
|
self.data = data
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
|
||||||
|
|
||||||
|
class SelectionInitializer(DataAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||||
|
return self.transform(self.data[indices])
|
||||||
|
|
||||||
|
|
||||||
|
class MeanInitializer(DataAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
mean = torch.mean(self.data, dim=0)
|
||||||
|
repeat_dim = [length] + [1] * len(mean.shape)
|
||||||
|
return self.transform(mean.repeat(repeat_dim))
|
||||||
|
|
||||||
|
|
||||||
|
class ClassAwareInitializer(DataAwareInitializer):
|
||||||
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
|
data, targets = parse_data_arg(data)
|
||||||
|
super().__init__(data, transform)
|
||||||
|
self.targets = targets
|
||||||
|
self.clabels = torch.unique(self.targets).int().tolist()
|
||||||
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
|
def _get_samples_from_initializer(self, length, dist):
|
||||||
|
if not dist:
|
||||||
|
per_class = length // self.num_classes
|
||||||
|
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
||||||
|
if isinstance(dist, list):
|
||||||
|
dist = dict(zip(self.clabels, dist))
|
||||||
|
samples = [self.initializers[k].generate(n) for k, n in dist.items()]
|
||||||
|
out = torch.vstack(samples)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = self.transform(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
del self.targets
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
super().__init__(data, **kwargs)
|
||||||
|
self.initializers = get_subinitializers(self.data, self.targets,
|
||||||
|
self.clabels, MeanInitializer)
|
||||||
|
|
||||||
|
def generate(self, length, dist):
|
||||||
|
samples = self._get_samples_from_initializer(length, dist)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, data, noise=None, **kwargs):
|
||||||
|
super().__init__(data, **kwargs)
|
||||||
|
self.noise = noise
|
||||||
|
self.initializers = get_subinitializers(self.data, self.targets,
|
||||||
|
self.clabels,
|
||||||
|
SelectionInitializer)
|
||||||
|
|
||||||
|
def add_noise_v1(self, x):
|
||||||
|
return x + self.noise
|
||||||
|
|
||||||
|
def add_noise_v2(self, x):
|
||||||
|
"""Shifts some dimensions of the data randomly."""
|
||||||
|
n1 = torch.rand_like(x)
|
||||||
|
n2 = torch.rand_like(x)
|
||||||
|
mask = torch.bernoulli(n1) - torch.bernoulli(n2)
|
||||||
|
return x + (self.noise * mask)
|
||||||
|
|
||||||
|
def generate(self, length, dist):
|
||||||
|
samples = self._get_samples_from_initializer(length, dist)
|
||||||
|
if self.noise is not None:
|
||||||
|
samples = self.add_noise_v1(samples)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
class LabelsInitializer:
|
||||||
|
def generate(self):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class UnequalLabelsInitializer(LabelsInitializer):
|
||||||
|
def __init__(self, dist):
|
||||||
|
self.dist = dist
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
return self.dist
|
||||||
|
|
||||||
|
def generate(self, clabels=None, dist=None):
|
||||||
|
if not clabels:
|
||||||
|
clabels = range(len(self.dist))
|
||||||
|
if not dist:
|
||||||
|
dist = self.dist
|
||||||
|
targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
|
||||||
|
return torch.LongTensor(targets)
|
||||||
|
|
||||||
|
|
||||||
|
class EqualLabelsInitializer(LabelsInitializer):
|
||||||
|
def __init__(self, classes, per_class):
|
||||||
|
self.classes = classes
|
||||||
|
self.per_class = per_class
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
return self.classes * [self.per_class]
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLabelsInitializer(UnequalLabelsInitializer):
|
||||||
|
def generate(self):
|
||||||
|
clabels = list(self.dist.keys())
|
||||||
|
dist = list(self.dist.values())
|
||||||
|
return super().generate(clabels, dist)
|
||||||
|
|
||||||
|
|
||||||
|
# Reasonings
|
||||||
|
class ReasoningsInitializer:
|
||||||
|
def generate(self, length):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||||
|
def __init__(self, classes, length):
|
||||||
|
self.classes = classes
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.zeros((self.length, self.classes, 2))
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||||
|
SMI = StratifiedMeanInitializer
|
||||||
|
Random = RandomInitializer = UniformInitializer
|
||||||
|
Zeros = ZerosInitializer
|
||||||
|
Ones = OnesInitializer
|
@@ -1,7 +1,6 @@
|
|||||||
"""ProtoTorch datasets."""
|
"""ProtoTorch datasets."""
|
||||||
|
|
||||||
|
from .abstract import NumpyDataset
|
||||||
|
from .sklearn import Blobs, Circles, Iris, Moons, Random
|
||||||
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'Tecator',
|
|
||||||
]
|
|
||||||
|
@@ -12,6 +12,15 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
|
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||||
|
def __init__(self, data, targets):
|
||||||
|
self.data = torch.Tensor(data)
|
||||||
|
self.targets = torch.LongTensor(targets)
|
||||||
|
tensors = [self.data, self.targets]
|
||||||
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
"""Abstract dataset class to be inherited."""
|
"""Abstract dataset class to be inherited."""
|
||||||
|
|
||||||
@@ -44,15 +53,13 @@ class ProtoDataset(Dataset):
|
|||||||
self._download()
|
self._download()
|
||||||
|
|
||||||
if not self._check_exists():
|
if not self._check_exists():
|
||||||
raise RuntimeError(
|
raise RuntimeError("Dataset not found. "
|
||||||
"Dataset not found. " "You can use download=True to download it"
|
"You can use download=True to download it")
|
||||||
)
|
|
||||||
|
|
||||||
data_file = self.training_file if self.train else self.test_file
|
data_file = self.training_file if self.train else self.test_file
|
||||||
|
|
||||||
self.data, self.targets = torch.load(
|
self.data, self.targets = torch.load(
|
||||||
os.path.join(self.processed_folder, data_file)
|
os.path.join(self.processed_folder, data_file))
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raw_folder(self):
|
def raw_folder(self):
|
||||||
@@ -68,8 +75,9 @@ class ProtoDataset(Dataset):
|
|||||||
|
|
||||||
def _check_exists(self):
|
def _check_exists(self):
|
||||||
return os.path.exists(
|
return os.path.exists(
|
||||||
os.path.join(self.processed_folder, self.training_file)
|
os.path.join(
|
||||||
) and os.path.exists(os.path.join(self.processed_folder, self.test_file))
|
self.processed_folder, self.training_file)) and os.path.exists(
|
||||||
|
os.path.join(self.processed_folder, self.test_file))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
head = "Dataset " + self.__class__.__name__
|
head = "Dataset " + self.__class__.__name__
|
||||||
|
137
prototorch/datasets/sklearn.py
Normal file
137
prototorch/datasets/sklearn.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Thin wrappers for a few scikit-learn datasets.
|
||||||
|
|
||||||
|
URL:
|
||||||
|
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
|
||||||
|
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
||||||
|
make_classification, make_moons)
|
||||||
|
|
||||||
|
|
||||||
|
class Iris(NumpyDataset):
|
||||||
|
"""Iris Dataset by Ronald Fisher introduced in 1936.
|
||||||
|
|
||||||
|
The dataset contains four measurements from flowers of three species of iris.
|
||||||
|
|
||||||
|
.. list-table:: Iris
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 4
|
||||||
|
- 3
|
||||||
|
- 150
|
||||||
|
- 0
|
||||||
|
- 0
|
||||||
|
|
||||||
|
:param dims: select a subset of dimensions
|
||||||
|
"""
|
||||||
|
def __init__(self, dims: Sequence[int] = None):
|
||||||
|
x, y = load_iris(return_X_y=True)
|
||||||
|
if dims:
|
||||||
|
x = x[:, dims]
|
||||||
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Blobs(NumpyDataset):
|
||||||
|
"""Generate isotropic Gaussian blobs for clustering.
|
||||||
|
|
||||||
|
Read more at
|
||||||
|
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_samples: int = 300,
|
||||||
|
num_features: int = 2,
|
||||||
|
seed: Union[None, int] = 0):
|
||||||
|
x, y = make_blobs(num_samples,
|
||||||
|
num_features,
|
||||||
|
centers=None,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False)
|
||||||
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Random(NumpyDataset):
|
||||||
|
"""Generate a random n-class classification problem.
|
||||||
|
|
||||||
|
Read more at
|
||||||
|
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.
|
||||||
|
|
||||||
|
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_samples: int = 300,
|
||||||
|
num_features: int = 2,
|
||||||
|
num_classes: int = 2,
|
||||||
|
num_clusters: int = 2,
|
||||||
|
num_informative: Union[None, int] = None,
|
||||||
|
separation: float = 1.0,
|
||||||
|
seed: Union[None, int] = 0):
|
||||||
|
if not num_informative:
|
||||||
|
import math
|
||||||
|
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
||||||
|
if num_features < num_informative:
|
||||||
|
warnings.warn("Generating more features than requested.")
|
||||||
|
num_features = num_informative
|
||||||
|
x, y = make_classification(num_samples,
|
||||||
|
num_features,
|
||||||
|
n_informative=num_informative,
|
||||||
|
n_redundant=0,
|
||||||
|
n_classes=num_classes,
|
||||||
|
n_clusters_per_class=num_clusters,
|
||||||
|
class_sep=separation,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False)
|
||||||
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Circles(NumpyDataset):
|
||||||
|
"""Make a large circle containing a smaller circle in 2D.
|
||||||
|
|
||||||
|
A simple toy dataset to visualize clustering and classification algorithms.
|
||||||
|
|
||||||
|
Read more at
|
||||||
|
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_samples: int = 300,
|
||||||
|
noise: float = 0.3,
|
||||||
|
factor: float = 0.8,
|
||||||
|
seed: Union[None, int] = 0):
|
||||||
|
x, y = make_circles(num_samples,
|
||||||
|
noise=noise,
|
||||||
|
factor=factor,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False)
|
||||||
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Moons(NumpyDataset):
|
||||||
|
"""Make two interleaving half circles.
|
||||||
|
|
||||||
|
A simple toy dataset to visualize clustering and classification algorithms.
|
||||||
|
|
||||||
|
Read more at
|
||||||
|
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_samples: int = 300,
|
||||||
|
noise: float = 0.3,
|
||||||
|
seed: Union[None, int] = 0):
|
||||||
|
x, y = make_moons(num_samples,
|
||||||
|
noise=noise,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False)
|
||||||
|
super().__init__(x, y)
|
57
prototorch/datasets/spiral.py
Normal file
57
prototorch/datasets/spiral.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Spiral dataset for binary classification."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_spiral(num_samples=500, noise=0.3):
|
||||||
|
"""Generates the Spiral Dataset.
|
||||||
|
|
||||||
|
For use in Prototorch use `prototorch.datasets.Spiral` instead.
|
||||||
|
"""
|
||||||
|
def get_samples(n, delta_t):
|
||||||
|
points = []
|
||||||
|
for i in range(n):
|
||||||
|
r = i / num_samples * 5
|
||||||
|
t = 1.75 * i / n * 2 * np.pi + delta_t
|
||||||
|
x = r * np.sin(t) + np.random.rand(1) * noise
|
||||||
|
y = r * np.cos(t) + np.random.rand(1) * noise
|
||||||
|
points.append([x, y])
|
||||||
|
return points
|
||||||
|
|
||||||
|
n = num_samples // 2
|
||||||
|
positive = get_samples(n=n, delta_t=0)
|
||||||
|
negative = get_samples(n=n, delta_t=np.pi)
|
||||||
|
x = np.concatenate(
|
||||||
|
[np.array(positive).reshape(n, -1),
|
||||||
|
np.array(negative).reshape(n, -1)],
|
||||||
|
axis=0)
|
||||||
|
y = np.concatenate([np.zeros(n), np.ones(n)])
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class Spiral(torch.utils.data.TensorDataset):
|
||||||
|
"""Spiral dataset for binary classification.
|
||||||
|
|
||||||
|
This datasets consists of two spirals of two different classes.
|
||||||
|
|
||||||
|
.. list-table:: Spiral
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 2
|
||||||
|
- 2
|
||||||
|
- num_samples
|
||||||
|
- 0
|
||||||
|
- 0
|
||||||
|
|
||||||
|
:param num_samples: number of random samples
|
||||||
|
:param noise: noise added to the spirals
|
||||||
|
"""
|
||||||
|
def __init__(self, num_samples: int = 500, noise: float = 0.3):
|
||||||
|
x, y = make_spiral(num_samples, noise)
|
||||||
|
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
@@ -40,19 +40,34 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torchvision.datasets.utils import download_file_from_google_drive
|
|
||||||
|
|
||||||
from prototorch.datasets.abstract import ProtoDataset
|
from prototorch.datasets.abstract import ProtoDataset
|
||||||
|
from torchvision.datasets.utils import download_file_from_google_drive
|
||||||
|
|
||||||
|
|
||||||
class Tecator(ProtoDataset):
|
class Tecator(ProtoDataset):
|
||||||
"""
|
"""
|
||||||
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
|
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
|
||||||
for classification.
|
|
||||||
|
The dataset contains wavelength measurements of meat.
|
||||||
|
|
||||||
|
.. list-table:: Tecator
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 100
|
||||||
|
- 2
|
||||||
|
- 129
|
||||||
|
- 43
|
||||||
|
- 43
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_resources = [
|
_resources = [
|
||||||
("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
|
("1P9WIYnyxFPh6f1vqAbnKfK8oYmUgyV83",
|
||||||
|
"ba5607c580d0f91bb27dc29d13c2f8df"),
|
||||||
] # (google_storage_id, md5hash)
|
] # (google_storage_id, md5hash)
|
||||||
classes = ["0 - low_fat", "1 - high_fat"]
|
classes = ["0 - low_fat", "1 - high_fat"]
|
||||||
|
|
||||||
@@ -74,29 +89,31 @@ class Tecator(ProtoDataset):
|
|||||||
print("Downloading...")
|
print("Downloading...")
|
||||||
for fileid, md5 in self._resources:
|
for fileid, md5 in self._resources:
|
||||||
filename = "tecator.npz"
|
filename = "tecator.npz"
|
||||||
download_file_from_google_drive(
|
download_file_from_google_drive(fileid,
|
||||||
fileid, root=self.raw_folder, filename=filename, md5=md5
|
root=self.raw_folder,
|
||||||
)
|
filename=filename,
|
||||||
|
md5=md5)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Processing...")
|
print("Processing...")
|
||||||
with np.load(
|
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||||
os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
|
allow_pickle=False) as f:
|
||||||
) as f:
|
|
||||||
x_train, y_train = f["x_train"], f["y_train"]
|
x_train, y_train = f["x_train"], f["y_train"]
|
||||||
x_test, y_test = f["x_test"], f["y_test"]
|
x_test, y_test = f["x_test"], f["y_test"]
|
||||||
training_set = [
|
training_set = [
|
||||||
torch.tensor(x_train, dtype=torch.float32),
|
torch.Tensor(x_train),
|
||||||
torch.tensor(y_train),
|
torch.LongTensor(y_train),
|
||||||
]
|
]
|
||||||
test_set = [
|
test_set = [
|
||||||
torch.tensor(x_test, dtype=torch.float32),
|
torch.Tensor(x_test),
|
||||||
torch.tensor(y_test),
|
torch.LongTensor(y_test),
|
||||||
]
|
]
|
||||||
|
|
||||||
with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
|
with open(os.path.join(self.processed_folder, self.training_file),
|
||||||
|
"wb") as f:
|
||||||
torch.save(training_set, f)
|
torch.save(training_set, f)
|
||||||
with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
|
with open(os.path.join(self.processed_folder, self.test_file),
|
||||||
|
"wb") as f:
|
||||||
torch.save(test_set, f)
|
torch.save(test_set, f)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@@ -2,11 +2,4 @@
|
|||||||
|
|
||||||
from .activations import identity, sigmoid_beta, swish_beta
|
from .activations import identity, sigmoid_beta, swish_beta
|
||||||
from .competitions import knnc, wtac
|
from .competitions import knnc, wtac
|
||||||
|
from .pooling import *
|
||||||
__all__ = [
|
|
||||||
'identity',
|
|
||||||
'sigmoid_beta',
|
|
||||||
'swish_beta',
|
|
||||||
'knnc',
|
|
||||||
'wtac',
|
|
||||||
]
|
|
||||||
|
@@ -5,51 +5,49 @@ import torch
|
|||||||
ACTIVATIONS = dict()
|
ACTIVATIONS = dict()
|
||||||
|
|
||||||
|
|
||||||
# def register_activation(scriptf):
|
def register_activation(fn):
|
||||||
# ACTIVATIONS[scriptf.name] = scriptf
|
|
||||||
# return scriptf
|
|
||||||
def register_activation(function):
|
|
||||||
"""Add the activation function to the registry."""
|
"""Add the activation function to the registry."""
|
||||||
ACTIVATIONS[function.__name__] = function
|
name = fn.__name__
|
||||||
return function
|
ACTIVATIONS[name] = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
def identity(x, beta=0.0):
|
||||||
def identity(x, beta=torch.tensor(0)):
|
|
||||||
"""Identity activation function.
|
"""Identity activation function.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = x`
|
:math:`f(x) = x`
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
beta (`float`): Ignored.
|
||||||
"""
|
"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
def sigmoid_beta(x, beta=10.0):
|
||||||
def sigmoid_beta(x, beta=torch.tensor(10)):
|
|
||||||
r"""Sigmoid activation function with scaling.
|
r"""Sigmoid activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
|
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (`torch.tensor`): Scaling parameter :math:`\beta`
|
beta (`float`): Scaling parameter :math:`\beta`
|
||||||
"""
|
"""
|
||||||
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
|
out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
def swish_beta(x, beta=10.0):
|
||||||
def swish_beta(x, beta=torch.tensor(10)):
|
|
||||||
r"""Swish activation function with scaling.
|
r"""Swish activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
|
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (`torch.tensor`): Scaling parameter :math:`\beta`
|
beta (`float`): Scaling parameter :math:`\beta`
|
||||||
"""
|
"""
|
||||||
out = x * sigmoid_beta(x, beta=beta)
|
out = x * sigmoid_beta(x, beta=beta)
|
||||||
return out
|
return out
|
||||||
@@ -61,4 +59,4 @@ def get_activation(funcname):
|
|||||||
return funcname
|
return funcname
|
||||||
if funcname in ACTIVATIONS:
|
if funcname in ACTIVATIONS:
|
||||||
return ACTIVATIONS.get(funcname)
|
return ACTIVATIONS.get(funcname)
|
||||||
raise NameError(f'Activation {funcname} was not found.')
|
raise NameError(f"Activation {funcname} was not found.")
|
||||||
|
@@ -3,43 +3,26 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# @torch.jit.script
|
def wtac(distances: torch.Tensor,
|
||||||
def stratified_min(distances, labels):
|
labels: torch.LongTensor) -> (torch.LongTensor):
|
||||||
clabels = torch.unique(labels, dim=0)
|
"""Winner-Takes-All-Competition.
|
||||||
nclasses = clabels.size()[0]
|
|
||||||
if distances.size()[1] == nclasses:
|
|
||||||
# skip if only one prototype per class
|
|
||||||
return distances
|
|
||||||
batch_size = distances.size()[0]
|
|
||||||
winning_distances = torch.zeros(nclasses, batch_size)
|
|
||||||
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
|
||||||
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
|
||||||
for i, cl in enumerate(clabels):
|
|
||||||
# cdists = distances.T[labels == cl]
|
|
||||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
|
||||||
if labels.ndim == 2:
|
|
||||||
# if the labels are one-hot vectors
|
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
|
||||||
cdists = torch.where(matcher, distances.T, inf).T
|
|
||||||
winning_distances[i] = torch.min(cdists, dim=1,
|
|
||||||
keepdim=True).values.squeeze()
|
|
||||||
if labels.ndim == 2:
|
|
||||||
# Transpose to return with `batch_size` first and
|
|
||||||
# reverse the columns to fix the ordering of the classes
|
|
||||||
return torch.flip(winning_distances.T, dims=(1, ))
|
|
||||||
|
|
||||||
return winning_distances.T # return with `batch_size` first
|
Returns the labels corresponding to the winners.
|
||||||
|
|
||||||
|
"""
|
||||||
# @torch.jit.script
|
|
||||||
def wtac(distances, labels):
|
|
||||||
winning_indices = torch.min(distances, dim=1).indices
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
# @torch.jit.script
|
def knnc(distances: torch.Tensor,
|
||||||
def knnc(distances, labels, k):
|
labels: torch.LongTensor,
|
||||||
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
|
k: int = 1) -> (torch.LongTensor):
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
"""K-Nearest-Neighbors-Competition.
|
||||||
|
|
||||||
|
Returns the labels corresponding to the winners.
|
||||||
|
|
||||||
|
"""
|
||||||
|
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||||
|
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
@@ -1,12 +1,9 @@
|
|||||||
"""ProtoTorch distance functions."""
|
"""ProtoTorch distance functions."""
|
||||||
|
|
||||||
import torch
|
|
||||||
from prototorch.functions.helper import (
|
|
||||||
equal_int_shape,
|
|
||||||
_int_and_mixed_shape,
|
|
||||||
_check_shapes,
|
|
||||||
)
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||||
|
equal_int_shape, get_flat)
|
||||||
|
|
||||||
|
|
||||||
def squared_euclidean_distance(x, y):
|
def squared_euclidean_distance(x, y):
|
||||||
@@ -14,12 +11,10 @@ def squared_euclidean_distance(x, y):
|
|||||||
|
|
||||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
|
|
||||||
**Alias:**
|
**Alias:**
|
||||||
``prototorch.functions.distances.sed``
|
``prototorch.functions.distances.sed``
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
expanded_x = x.unsqueeze(dim=1)
|
expanded_x = x.unsqueeze(dim=1)
|
||||||
batchwise_difference = y - expanded_x
|
batchwise_difference = y - expanded_x
|
||||||
differences_raised = torch.pow(batchwise_difference, 2)
|
differences_raised = torch.pow(batchwise_difference, 2)
|
||||||
@@ -32,30 +27,40 @@ def euclidean_distance(x, y):
|
|||||||
|
|
||||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||||
|
|
||||||
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
|
|
||||||
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
|
|
||||||
|
|
||||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
distances_raised = squared_euclidean_distance(x, y)
|
distances_raised = squared_euclidean_distance(x, y)
|
||||||
distances = torch.sqrt(distances_raised)
|
distances = torch.sqrt(distances_raised)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
def euclidean_distance_v2(x, y):
|
||||||
|
x, y = get_flat(x, y)
|
||||||
|
diff = y - x.unsqueeze(1)
|
||||||
|
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||||
|
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||||
|
# batch diagonal. See:
|
||||||
|
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||||
|
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||||
|
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
||||||
|
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
||||||
|
# print(f"{distances.shape=}") # (nx, ny)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def lpnorm_distance(x, y, p):
|
def lpnorm_distance(x, y, p):
|
||||||
r"""
|
r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
||||||
Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
|
||||||
Also known as Minkowski distance.
|
Also known as Minkowski distance.
|
||||||
|
|
||||||
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
||||||
|
|
||||||
Calls ``torch.cdist``
|
Calls ``torch.cdist``
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param p: p parameter of the lp norm
|
:param p: p parameter of the lp norm
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
distances = torch.cdist(x, y, p=p)
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@@ -65,10 +70,9 @@ def omega_distance(x, y, omega):
|
|||||||
|
|
||||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param `torch.tensor` omega: Two dimensional matrix
|
:param `torch.tensor` omega: Two dimensional matrix
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
projected_x = x @ omega
|
projected_x = x @ omega
|
||||||
projected_y = y @ omega
|
projected_y = y @ omega
|
||||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||||
@@ -80,15 +84,14 @@ def lomega_distance(x, y, omegas):
|
|||||||
|
|
||||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||||
|
|
||||||
:param `torch.tensor` x: Two dimensional vector
|
|
||||||
:param `torch.tensor` y: Two dimensional vector
|
|
||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
"""
|
"""
|
||||||
|
x, y = get_flat(x, y)
|
||||||
projected_x = x @ omegas
|
projected_x = x @ omegas
|
||||||
projected_y = torch.diagonal(y @ omegas).T
|
projected_y = torch.diagonal(y @ omegas).T
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
batchwise_difference = expanded_y - projected_x
|
batchwise_difference = expanded_y - projected_x
|
||||||
differences_squared = batchwise_difference ** 2
|
differences_squared = batchwise_difference**2
|
||||||
distances = torch.sum(differences_squared, dim=2)
|
distances = torch.sum(differences_squared, dim=2)
|
||||||
distances = distances.permute(1, 0)
|
distances = distances.permute(1, 0)
|
||||||
return distances
|
return distances
|
||||||
@@ -107,26 +110,18 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
|||||||
for tensor in [x, y]:
|
for tensor in [x, y]:
|
||||||
if tensor.ndim != 2:
|
if tensor.ndim != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The tensor dimension must be two. You provide: tensor.ndim="
|
"The tensor dimension must be two. You provide: tensor.ndim=" +
|
||||||
+ str(tensor.ndim)
|
str(tensor.ndim) + ".")
|
||||||
+ "."
|
|
||||||
)
|
|
||||||
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
|
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
|
||||||
+ str(tuple(x.shape)[1])
|
+ str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
|
||||||
+ " and tuple(y.shape)(y)[1]="
|
str(tuple(y.shape)[1]) + ".")
|
||||||
+ str(tuple(y.shape)[1])
|
|
||||||
+ "."
|
|
||||||
)
|
|
||||||
|
|
||||||
y = torch.transpose(y)
|
y = torch.transpose(y)
|
||||||
|
|
||||||
diss = (
|
diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
|
||||||
torch.sum(x ** 2, axis=1, keepdims=True)
|
torch.sum(y**2, axis=0, keepdims=True))
|
||||||
- 2 * torch.dot(x, y)
|
|
||||||
+ torch.sum(y ** 2, axis=0, keepdims=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not squared:
|
if not squared:
|
||||||
if epsilon == 0:
|
if epsilon == 0:
|
||||||
@@ -173,19 +168,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
if subspaces.ndim == 2:
|
if subspaces.ndim == 2:
|
||||||
# clean solution without map if the matrix_scope is global
|
# clean solution without map if the matrix_scope is global
|
||||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
||||||
subspaces, torch.transpose(subspaces)
|
subspaces, torch.transpose(subspaces))
|
||||||
)
|
|
||||||
|
|
||||||
projected_signals = torch.dot(signals, projectors)
|
projected_signals = torch.dot(signals, projectors)
|
||||||
projected_protos = torch.dot(protos, projectors)
|
projected_protos = torch.dot(protos, projectors)
|
||||||
|
|
||||||
diss = euclidean_distance_matrix(
|
diss = euclidean_distance_matrix(projected_signals,
|
||||||
projected_signals, projected_protos, squared=squared, epsilon=epsilon
|
projected_protos,
|
||||||
)
|
squared=squared,
|
||||||
|
epsilon=epsilon)
|
||||||
|
|
||||||
diss = torch.reshape(
|
diss = torch.reshape(
|
||||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
|
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||||
)
|
|
||||||
|
|
||||||
return torch.permute(diss, [0, 2, 1])
|
return torch.permute(diss, [0, 2, 1])
|
||||||
|
|
||||||
@@ -193,21 +187,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# no solution without map possible --> memory efficient but slow!
|
# no solution without map possible --> memory efficient but slow!
|
||||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
||||||
subspaces, subspaces
|
subspaces,
|
||||||
) # K.batch_dot(subspaces, subspaces, [2, 2])
|
subspaces) # K.batch_dot(subspaces, subspaces, [2, 2])
|
||||||
|
|
||||||
projected_protos = (
|
projected_protos = (protos @ subspaces
|
||||||
protos @ subspaces
|
).T # K.batch_dot(projectors, protos, [1, 1]))
|
||||||
).T # K.batch_dot(projectors, protos, [1, 1]))
|
|
||||||
|
|
||||||
def projected_norm(projector):
|
def projected_norm(projector):
|
||||||
return torch.sum(torch.dot(signals, projector) ** 2, axis=1)
|
return torch.sum(torch.dot(signals, projector)**2, axis=1)
|
||||||
|
|
||||||
diss = (
|
diss = (torch.transpose(map(projected_norm, projectors)) -
|
||||||
torch.transpose(map(projected_norm, projectors))
|
2 * torch.dot(signals, projected_protos) +
|
||||||
- 2 * torch.dot(signals, projected_protos)
|
torch.sum(projected_protos**2, axis=0, keepdims=True))
|
||||||
+ torch.sum(projected_protos ** 2, axis=0, keepdims=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not squared:
|
if not squared:
|
||||||
if epsilon == 0:
|
if epsilon == 0:
|
||||||
@@ -216,8 +207,7 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||||
|
|
||||||
diss = torch.reshape(
|
diss = torch.reshape(
|
||||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
|
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||||
)
|
|
||||||
|
|
||||||
return torch.permute(diss, [0, 2, 1])
|
return torch.permute(diss, [0, 2, 1])
|
||||||
|
|
||||||
@@ -233,12 +223,12 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# Scope: Tangentspace Projections
|
# Scope: Tangentspace Projections
|
||||||
diff = torch.reshape(
|
diff = torch.reshape(
|
||||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
|
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||||
)
|
|
||||||
projected_diff = diff @ projectors
|
projected_diff = diff @ projectors
|
||||||
projected_diff = torch.reshape(
|
projected_diff = torch.reshape(
|
||||||
projected_diff,
|
projected_diff,
|
||||||
(signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
|
(signal_shape[0], signal_shape[2], signal_shape[1]) +
|
||||||
|
signal_shape[3:],
|
||||||
)
|
)
|
||||||
|
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||||
@@ -251,13 +241,13 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# Scope: Tangentspace Projections
|
# Scope: Tangentspace Projections
|
||||||
diff = torch.reshape(
|
diff = torch.reshape(
|
||||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
|
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||||
)
|
|
||||||
diff = diff.permute([1, 0, 2])
|
diff = diff.permute([1, 0, 2])
|
||||||
projected_diff = torch.bmm(diff, projectors)
|
projected_diff = torch.bmm(diff, projectors)
|
||||||
projected_diff = torch.reshape(
|
projected_diff = torch.reshape(
|
||||||
projected_diff,
|
projected_diff,
|
||||||
(signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
|
(signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||||
|
signal_shape[3:],
|
||||||
)
|
)
|
||||||
|
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||||
|
@@ -1,6 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_flat(*args):
|
||||||
|
rv = [x.view(x.size(0), -1) for x in args]
|
||||||
|
return rv
|
||||||
|
|
||||||
|
|
||||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||||
"""Computes the accuracy of a prototype based model.
|
"""Computes the accuracy of a prototype based model.
|
||||||
via Winner-Takes-All rule.
|
via Winner-Takes-All rule.
|
||||||
@@ -23,7 +28,7 @@ def predict_label(y_pred, plabels):
|
|||||||
|
|
||||||
def mixed_shape(inputs):
|
def mixed_shape(inputs):
|
||||||
if not torch.is_tensor(inputs):
|
if not torch.is_tensor(inputs):
|
||||||
raise ValueError('Input must be a tensor.')
|
raise ValueError("Input must be a tensor.")
|
||||||
else:
|
else:
|
||||||
int_shape = list(inputs.shape)
|
int_shape = list(inputs.shape)
|
||||||
# sometimes int_shape returns mixed integer types
|
# sometimes int_shape returns mixed integer types
|
||||||
@@ -39,11 +44,11 @@ def mixed_shape(inputs):
|
|||||||
def equal_int_shape(shape_1, shape_2):
|
def equal_int_shape(shape_1, shape_2):
|
||||||
if not isinstance(shape_1,
|
if not isinstance(shape_1,
|
||||||
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
||||||
raise ValueError('Input shapes must list or tuple.')
|
raise ValueError("Input shapes must list or tuple.")
|
||||||
for shape in [shape_1, shape_2]:
|
for shape in [shape_1, shape_2]:
|
||||||
if not all([isinstance(x, int) or x is None for x in shape]):
|
if not all([isinstance(x, int) or x is None for x in shape]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Input shapes must be list or tuple of int and None values.')
|
"Input shapes must be list or tuple of int and None values.")
|
||||||
|
|
||||||
if len(shape_1) != len(shape_2):
|
if len(shape_1) != len(shape_2):
|
||||||
return False
|
return False
|
||||||
|
@@ -15,59 +15,59 @@ def register_initializer(function):
|
|||||||
|
|
||||||
def labels_from(distribution, one_hot=True):
|
def labels_from(distribution, one_hot=True):
|
||||||
"""Takes a distribution tensor and returns a labels tensor."""
|
"""Takes a distribution tensor and returns a labels tensor."""
|
||||||
nclasses = distribution.shape[0]
|
num_classes = distribution.shape[0]
|
||||||
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
|
llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
|
||||||
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
||||||
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
|
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
|
||||||
plabels = torch.tensor(flat_llist, requires_grad=False)
|
plabels = torch.tensor(flat_llist, requires_grad=False)
|
||||||
if one_hot:
|
if one_hot:
|
||||||
return torch.eye(nclasses)[plabels]
|
return torch.eye(num_classes)[plabels]
|
||||||
return plabels
|
return plabels
|
||||||
|
|
||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def ones(x_train, y_train, prototype_distribution, one_hot=True):
|
def ones(x_train, y_train, prototype_distribution, one_hot=True):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
protos = torch.ones(nprotos, *x_train.shape[1:])
|
protos = torch.ones(num_protos, *x_train.shape[1:])
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
return protos, plabels
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
|
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
protos = torch.zeros(nprotos, *x_train.shape[1:])
|
protos = torch.zeros(num_protos, *x_train.shape[1:])
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
return protos, plabels
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def rand(x_train, y_train, prototype_distribution, one_hot=True):
|
def rand(x_train, y_train, prototype_distribution, one_hot=True):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
protos = torch.rand(nprotos, *x_train.shape[1:])
|
protos = torch.rand(num_protos, *x_train.shape[1:])
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
return protos, plabels
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def randn(x_train, y_train, prototype_distribution, one_hot=True):
|
def randn(x_train, y_train, prototype_distribution, one_hot=True):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
protos = torch.randn(nprotos, *x_train.shape[1:])
|
protos = torch.randn(num_protos, *x_train.shape[1:])
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
return protos, plabels
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
pdim = x_train.shape[1]
|
pdim = x_train.shape[1]
|
||||||
protos = torch.empty(nprotos, pdim)
|
protos = torch.empty(num_protos, pdim)
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
for i, label in enumerate(plabels):
|
for i, label in enumerate(plabels):
|
||||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||||
if one_hot:
|
if one_hot:
|
||||||
nclasses = y_train.size()[1]
|
num_classes = y_train.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||||
xl = x_train[matcher]
|
xl = x_train[matcher]
|
||||||
mean_xl = torch.mean(xl, dim=0)
|
mean_xl = torch.mean(xl, dim=0)
|
||||||
protos[i] = mean_xl
|
protos[i] = mean_xl
|
||||||
@@ -81,15 +81,15 @@ def stratified_random(x_train,
|
|||||||
prototype_distribution,
|
prototype_distribution,
|
||||||
one_hot=True,
|
one_hot=True,
|
||||||
epsilon=1e-7):
|
epsilon=1e-7):
|
||||||
nprotos = torch.sum(prototype_distribution)
|
num_protos = torch.sum(prototype_distribution)
|
||||||
pdim = x_train.shape[1]
|
pdim = x_train.shape[1]
|
||||||
protos = torch.empty(nprotos, pdim)
|
protos = torch.empty(num_protos, pdim)
|
||||||
plabels = labels_from(prototype_distribution, one_hot)
|
plabels = labels_from(prototype_distribution, one_hot)
|
||||||
for i, label in enumerate(plabels):
|
for i, label in enumerate(plabels):
|
||||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||||
if one_hot:
|
if one_hot:
|
||||||
nclasses = y_train.size()[1]
|
num_classes = y_train.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||||
xl = x_train[matcher]
|
xl = x_train[matcher]
|
||||||
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||||
random_xl = xl[rand_index]
|
random_xl = xl[rand_index]
|
||||||
@@ -104,4 +104,4 @@ def get_initializer(funcname):
|
|||||||
return funcname
|
return funcname
|
||||||
if funcname in INITIALIZERS:
|
if funcname in INITIALIZERS:
|
||||||
return INITIALIZERS.get(funcname)
|
return INITIALIZERS.get(funcname)
|
||||||
raise NameError(f'Initializer {funcname} was not found.')
|
raise NameError(f"Initializer {funcname} was not found.")
|
||||||
|
@@ -3,20 +3,29 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _get_dp_dm(distances, targets, plabels):
|
def _get_matcher(targets, labels):
|
||||||
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
|
"""Returns a boolean tensor."""
|
||||||
if plabels.ndim == 2:
|
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
|
||||||
|
if labels.ndim == 2:
|
||||||
# if the labels are one-hot vectors
|
# if the labels are one-hot vectors
|
||||||
nclasses = targets.size()[1]
|
num_classes = targets.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||||
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dp_dm(distances, targets, plabels, with_indices=False):
|
||||||
|
"""Returns the d+ and d- values for a batch of distances."""
|
||||||
|
matcher = _get_matcher(targets, plabels)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
inf = torch.full_like(distances, fill_value=float("inf"))
|
||||||
d_matching = torch.where(matcher, distances, inf)
|
d_matching = torch.where(matcher, distances, inf)
|
||||||
d_unmatching = torch.where(not_matcher, distances, inf)
|
d_unmatching = torch.where(not_matcher, distances, inf)
|
||||||
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
dp = torch.min(d_matching, dim=-1, keepdim=True)
|
||||||
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
|
dm = torch.min(d_unmatching, dim=-1, keepdim=True)
|
||||||
return dp, dm
|
if with_indices:
|
||||||
|
return dp, dm
|
||||||
|
return dp.values, dm.values
|
||||||
|
|
||||||
|
|
||||||
def glvq_loss(distances, target_labels, prototype_labels):
|
def glvq_loss(distances, target_labels, prototype_labels):
|
||||||
@@ -24,3 +33,62 @@ def glvq_loss(distances, target_labels, prototype_labels):
|
|||||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
mu = (dp - dm) / (dp + dm)
|
mu = (dp - dm) / (dp + dm)
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
def lvq1_loss(distances, target_labels, prototype_labels):
|
||||||
|
"""LVQ1 loss function with support for one-hot labels.
|
||||||
|
|
||||||
|
See Section 4 [Sado&Yamada]
|
||||||
|
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
|
||||||
|
"""
|
||||||
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
|
mu = dp
|
||||||
|
mu[dp > dm] = -dm[dp > dm]
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
def lvq21_loss(distances, target_labels, prototype_labels):
|
||||||
|
"""LVQ2.1 loss function with support for one-hot labels.
|
||||||
|
|
||||||
|
See Section 4 [Sado&Yamada]
|
||||||
|
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
|
||||||
|
"""
|
||||||
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
|
mu = dp - dm
|
||||||
|
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
# Probabilistic
|
||||||
|
def _get_class_probabilities(probabilities, targets, prototype_labels):
|
||||||
|
# Create Label Mapping
|
||||||
|
uniques = prototype_labels.unique(sorted=True).tolist()
|
||||||
|
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||||
|
|
||||||
|
target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))
|
||||||
|
|
||||||
|
whole = probabilities.sum(dim=1)
|
||||||
|
correct = probabilities[torch.arange(len(probabilities)), target_indices]
|
||||||
|
wrong = whole - correct
|
||||||
|
|
||||||
|
return whole, correct, wrong
|
||||||
|
|
||||||
|
|
||||||
|
def nllr_loss(probabilities, targets, prototype_labels):
|
||||||
|
"""Compute the Negative Log-Likelihood Ratio loss."""
|
||||||
|
_, correct, wrong = _get_class_probabilities(probabilities, targets,
|
||||||
|
prototype_labels)
|
||||||
|
|
||||||
|
likelihood = correct / wrong
|
||||||
|
log_likelihood = torch.log(likelihood)
|
||||||
|
return -1.0 * log_likelihood
|
||||||
|
|
||||||
|
|
||||||
|
def rslvq_loss(probabilities, targets, prototype_labels):
|
||||||
|
"""Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
|
||||||
|
whole, correct, _ = _get_class_probabilities(probabilities, targets,
|
||||||
|
prototype_labels)
|
||||||
|
|
||||||
|
likelihood = correct / whole
|
||||||
|
log_likelihood = torch.log(likelihood)
|
||||||
|
return -1.0 * log_likelihood
|
||||||
|
@@ -1,7 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
80
prototorch/functions/pooling.py
Normal file
80
prototorch/functions/pooling.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""ProtoTorch pooling functions."""
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def stratify_with(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor,
|
||||||
|
fn: Callable,
|
||||||
|
fill_value: float = 0.0) -> (torch.Tensor):
|
||||||
|
"""Apply an arbitrary stratification strategy on the columns on `values`.
|
||||||
|
|
||||||
|
The outputs correspond to sorted labels.
|
||||||
|
"""
|
||||||
|
clabels = torch.unique(labels, dim=0, sorted=True)
|
||||||
|
num_classes = clabels.size()[0]
|
||||||
|
if values.size()[1] == num_classes:
|
||||||
|
# skip if stratification is trivial
|
||||||
|
return values
|
||||||
|
batch_size = values.size()[0]
|
||||||
|
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
||||||
|
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||||
|
for i, cl in enumerate(clabels):
|
||||||
|
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# if the labels are one-hot vectors
|
||||||
|
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||||
|
cdists = torch.where(matcher, values.T, filler).T
|
||||||
|
winning_values[i] = fn(cdists)
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# Transpose to return with `batch_size` first and
|
||||||
|
# reverse the columns to fix the ordering of the classes
|
||||||
|
return torch.flip(winning_values.T, dims=(1, ))
|
||||||
|
|
||||||
|
return winning_values.T # return with `batch_size` first
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_sum_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise sum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
|
||||||
|
fill_value=0.0)
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_min_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise minimum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
|
||||||
|
fill_value=float("inf"))
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_max_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise maximum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
|
||||||
|
fill_value=-1.0 * float("inf"))
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_prod_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise maximum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
|
||||||
|
fill_value=1.0)
|
||||||
|
return winning_values
|
18
prototorch/functions/similarities.py
Normal file
18
prototorch/functions/similarities.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""ProtoTorch similarity functions."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(x, y):
|
||||||
|
"""Compute the cosine similarity between :math:`x` and :math:`y`.
|
||||||
|
|
||||||
|
Expected dimension of x is 2.
|
||||||
|
Expected dimension of y is 2.
|
||||||
|
"""
|
||||||
|
norm_x = x.pow(2).sum(1).sqrt()
|
||||||
|
norm_y = y.pow(2).sum(1).sqrt()
|
||||||
|
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||||
|
epsilon = torch.finfo(norm_mat.dtype).eps
|
||||||
|
norm_mat.clamp_(min=epsilon)
|
||||||
|
similarities = (x @ y.T) / norm_mat
|
||||||
|
return similarities
|
5
prototorch/functions/transforms.py
Normal file
5
prototorch/functions/transforms.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(distance, variance):
|
||||||
|
return torch.exp(-(distance * distance) / (2 * variance))
|
@@ -1,7 +1,7 @@
|
|||||||
"""ProtoTorch modules."""
|
"""ProtoTorch modules."""
|
||||||
|
|
||||||
from .prototypes import Prototypes1D
|
from .competitions import *
|
||||||
|
from .initializers import *
|
||||||
__all__ = [
|
from .pooling import *
|
||||||
'Prototypes1D',
|
from .transformations import *
|
||||||
]
|
from .wrappers import LambdaLayer, LossLayer
|
||||||
|
41
prototorch/modules/competitions.py
Normal file
41
prototorch/modules/competitions.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""ProtoTorch Competition Modules."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.competitions import knnc, wtac
|
||||||
|
|
||||||
|
|
||||||
|
class WTAC(torch.nn.Module):
|
||||||
|
"""Winner-Takes-All-Competition Layer.
|
||||||
|
|
||||||
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def forward(self, distances, labels):
|
||||||
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class LTAC(torch.nn.Module):
|
||||||
|
"""Loser-Takes-All-Competition Layer.
|
||||||
|
|
||||||
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def forward(self, probs, labels):
|
||||||
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class KNNC(torch.nn.Module):
|
||||||
|
"""K-Nearest-Neighbors-Competition.
|
||||||
|
|
||||||
|
Thin wrapper over the `knnc` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, k=1, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def forward(self, distances, labels):
|
||||||
|
return knnc(distances, labels, k=self.k)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"k: {self.k}"
|
61
prototorch/modules/initializers.py
Normal file
61
prototorch/modules/initializers.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""ProtoTroch Module Initializers."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Transformations
|
||||||
|
class MatrixInitializer(object):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
...
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosInitializer(MatrixInitializer):
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.zeros(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class OnesInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.ones(shape) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, minimum=0.0, maximum=1.0, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.minimum = minimum
|
||||||
|
self.maximum = maximum
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.ones(shape).uniform_(self.minimum,
|
||||||
|
self.maximum) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class DataAwareInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
|
super().__init__()
|
||||||
|
self.data = data
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
|
||||||
|
|
||||||
|
class EigenVectorInitializer(DataAwareInitializer):
|
||||||
|
def generate(self, shape):
|
||||||
|
# TODO
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
EV = EigenVectorInitializer
|
||||||
|
Random = RandomInitializer = UniformInitializer
|
||||||
|
Zeros = ZerosInitializer
|
||||||
|
Ones = OnesInitializer
|
@@ -1,13 +1,12 @@
|
|||||||
"""ProtoTorch losses."""
|
"""ProtoTorch losses."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss
|
||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
|
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.squashing = get_activation(squashing)
|
||||||
@@ -18,3 +17,42 @@ class GLVQLoss(torch.nn.Module):
|
|||||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||||
return torch.sum(batch_loss, dim=0)
|
return torch.sum(batch_loss, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralGasEnergy(torch.nn.Module):
|
||||||
|
def __init__(self, lm, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.lm = lm
|
||||||
|
|
||||||
|
def forward(self, d):
|
||||||
|
order = torch.argsort(d, dim=1)
|
||||||
|
ranks = torch.argsort(order, dim=1)
|
||||||
|
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
|
||||||
|
|
||||||
|
return cost, order
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"lambda: {self.lm}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nghood_fn(rankings, lm):
|
||||||
|
return torch.exp(-rankings / lm)
|
||||||
|
|
||||||
|
|
||||||
|
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||||
|
def __init__(self, topology_layer, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.topology_layer = topology_layer
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nghood_fn(rankings, topology):
|
||||||
|
winner = rankings[:, 0]
|
||||||
|
|
||||||
|
weights = torch.zeros_like(rankings, dtype=torch.float)
|
||||||
|
weights[torch.arange(rankings.shape[0]), winner] = 1.0
|
||||||
|
|
||||||
|
neighbours = topology.get_neighbours(winner)
|
||||||
|
|
||||||
|
weights[neighbours] = 0.1
|
||||||
|
|
||||||
|
return weights
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
from torch import nn
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
from prototorch.functions.distances import euclidean_distance_matrix
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.normalization import orthogonalization
|
||||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(nn.Module):
|
class GTLVQ(nn.Module):
|
||||||
@@ -71,50 +70,42 @@ class GTLVQ(nn.Module):
|
|||||||
subspace_data=None,
|
subspace_data=None,
|
||||||
prototype_data=None,
|
prototype_data=None,
|
||||||
subspace_size=256,
|
subspace_size=256,
|
||||||
tangent_projection_type='local',
|
tangent_projection_type="local",
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
feature_dim=256,
|
feature_dim=256,
|
||||||
):
|
):
|
||||||
super(GTLVQ, self).__init__()
|
super(GTLVQ, self).__init__()
|
||||||
|
|
||||||
self.num_protos = num_classes * prototypes_per_class
|
self.num_protos = num_classes * prototypes_per_class
|
||||||
|
self.num_protos_class = prototypes_per_class
|
||||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||||
self.feature_dim = feature_dim
|
self.feature_dim = feature_dim
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
|
cls_initializer = StratifiedMeanInitializer(prototype_data)
|
||||||
|
cls_distribution = {
|
||||||
|
"num_classes": num_classes,
|
||||||
|
"prototypes_per_class": prototypes_per_class,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cls = LabeledComponents(cls_distribution, cls_initializer)
|
||||||
|
|
||||||
if subspace_data is None:
|
if subspace_data is None:
|
||||||
raise ValueError('Init Data must be specified!')
|
raise ValueError("Init Data must be specified!")
|
||||||
|
|
||||||
self.tpt = tangent_projection_type
|
self.tpt = tangent_projection_type
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
if self.tpt == "local":
|
||||||
self.init_local_subspace(subspace_data)
|
self.init_local_subspace(subspace_data, subspace_size,
|
||||||
elif self.tpt == 'global':
|
self.num_protos)
|
||||||
|
elif self.tpt == "global":
|
||||||
self.init_gobal_subspace(subspace_data, subspace_size)
|
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||||
else:
|
else:
|
||||||
self.subspaces = None
|
self.subspaces = None
|
||||||
|
|
||||||
# Hypothesis-Margin-Classifier
|
|
||||||
self.cls = Prototypes1D(input_dim=feature_dim,
|
|
||||||
prototypes_per_class=prototypes_per_class,
|
|
||||||
nclasses=num_classes,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=prototype_data)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Tangent Projection
|
if self.tpt == "local":
|
||||||
if self.tpt == 'local_proj':
|
dis = self.local_tangent_distances(x)
|
||||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
|
||||||
1).unsqueeze(2)
|
|
||||||
dis, proj_x = self.local_tangent_projection(x_conform)
|
|
||||||
|
|
||||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
|
||||||
self.feature_dim)
|
|
||||||
return proj_x, dis
|
|
||||||
elif self.tpt == "local":
|
|
||||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
|
||||||
1).unsqueeze(2)
|
|
||||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
|
||||||
self.subspaces)
|
|
||||||
elif self.tpt == "gloabl":
|
elif self.tpt == "gloabl":
|
||||||
dis = self.global_tangent_distances(x)
|
dis = self.global_tangent_distances(x)
|
||||||
else:
|
else:
|
||||||
@@ -127,55 +118,44 @@ class GTLVQ(nn.Module):
|
|||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
subspaces = subspace[:, :num_subspaces]
|
subspaces = subspace[:, :num_subspaces]
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||||
subspaces).clone().detach().requires_grad_(True)
|
|
||||||
|
|
||||||
def init_local_subspace(self, data):
|
def init_local_subspace(self, data, num_subspaces, num_protos):
|
||||||
_, _, v = torch.svd(data)
|
data = data - torch.mean(data, dim=0)
|
||||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
_, _, v = torch.svd(data, some=False)
|
||||||
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
v = v[:, :num_subspaces]
|
||||||
self.num_protos, 0)
|
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||||
subspaces).clone().detach().requires_grad_(True)
|
|
||||||
|
|
||||||
def global_tangent_distances(self, x):
|
def global_tangent_distances(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
|
x, projected_prototypes = (
|
||||||
|
x @ self.subspaces,
|
||||||
|
self.cls.prototypes @ self.subspaces,
|
||||||
|
)
|
||||||
# Euclidean Distance
|
# Euclidean Distance
|
||||||
return euclidean_distance_matrix(x, projected_prototypes)
|
return euclidean_distance_matrix(x, projected_prototypes)
|
||||||
|
|
||||||
def local_tangent_projection(self,
|
def local_tangent_distances(self, x):
|
||||||
signals):
|
|
||||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
|
||||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
|
||||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
|
||||||
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
|
||||||
# subspace should be orthogonalized
|
|
||||||
# Origin Source Code
|
|
||||||
# Origin Author:
|
|
||||||
protos = self.cls.prototypes
|
|
||||||
subspaces = self.subspaces
|
|
||||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
|
||||||
_, proto_int_shape = _int_and_mixed_shape(protos)
|
|
||||||
|
|
||||||
# check if the shapes are correct
|
# Tangent Distance
|
||||||
_check_shapes(signal_int_shape, proto_int_shape)
|
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
|
||||||
|
x.size(-1))
|
||||||
# Tangent Data Projections
|
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
|
||||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
self.cls.num_components,
|
||||||
data = signals.squeeze(2).permute([1, 0, 2])
|
x.size(-1))
|
||||||
projected_data = torch.bmm(data, subspaces)
|
projectors = torch.eye(
|
||||||
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
|
self.subspaces.shape[-2], device=x.device) - torch.bmm(
|
||||||
diff = projected_data - projected_protos
|
self.subspaces, self.subspaces.permute([0, 2, 1]))
|
||||||
projected_diff = torch.reshape(
|
diff = (x - protos)
|
||||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
diff = diff.permute([1, 0, 2])
|
||||||
signal_shape[3:])
|
diff = torch.bmm(diff, projectors)
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
diff = torch.norm(diff, 2, dim=-1).T
|
||||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
return diff
|
||||||
|
|
||||||
def get_parameters(self):
|
def get_parameters(self):
|
||||||
return {
|
return {
|
||||||
"params": self.cls.prototypes,
|
"params": self.cls.components,
|
||||||
}, {
|
}, {
|
||||||
"params": self.subspaces
|
"params": self.subspaces
|
||||||
}
|
}
|
||||||
@@ -183,8 +163,7 @@ class GTLVQ(nn.Module):
|
|||||||
def orthogonalize_subspace(self):
|
def orthogonalize_subspace(self):
|
||||||
if self.subspaces is not None:
|
if self.subspaces is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ortho_subpsaces = orthogonalization(
|
ortho_subpsaces = (orthogonalization(self.subspaces)
|
||||||
self.subspaces
|
if self.tpt == "global" else
|
||||||
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
|
torch.nn.init.orthogonal_(self.subspaces))
|
||||||
self.subspaces)
|
|
||||||
self.subspaces.copy_(ortho_subpsaces)
|
self.subspaces.copy_(ortho_subpsaces)
|
||||||
|
31
prototorch/modules/pooling.py
Normal file
31
prototorch/modules/pooling.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""ProtoTorch Pooling Modules."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.pooling import (stratified_max_pooling,
|
||||||
|
stratified_min_pooling,
|
||||||
|
stratified_prod_pooling,
|
||||||
|
stratified_sum_pooling)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_max_pooling(values, labels)
|
@@ -1,132 +0,0 @@
|
|||||||
"""ProtoTorch prototype modules."""
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from prototorch.functions.initializers import get_initializer
|
|
||||||
|
|
||||||
|
|
||||||
class _Prototypes(torch.nn.Module):
|
|
||||||
"""Abstract prototypes class."""
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def _validate_prototype_distribution(self):
|
|
||||||
if 0 in self.prototype_distribution:
|
|
||||||
warnings.warn("Are you sure about the `0` in "
|
|
||||||
"`prototype_distribution`?")
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"prototypes.shape: {tuple(self.prototypes.shape)}"
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
return self.prototypes, self.prototype_labels
|
|
||||||
|
|
||||||
|
|
||||||
class Prototypes1D(_Prototypes):
|
|
||||||
"""Create a learnable set of one-dimensional prototypes.
|
|
||||||
|
|
||||||
TODO Complete this doc-string.
|
|
||||||
"""
|
|
||||||
def __init__(self,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="ones",
|
|
||||||
prototype_distribution=None,
|
|
||||||
data=None,
|
|
||||||
dtype=torch.float32,
|
|
||||||
one_hot_labels=False,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
# Convert tensors to python lists before processing
|
|
||||||
if prototype_distribution is not None:
|
|
||||||
if not isinstance(prototype_distribution, list):
|
|
||||||
prototype_distribution = prototype_distribution.tolist()
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
if "input_dim" not in kwargs:
|
|
||||||
raise NameError("`input_dim` required if "
|
|
||||||
"no `data` is provided.")
|
|
||||||
if prototype_distribution:
|
|
||||||
kwargs_nclasses = sum(prototype_distribution)
|
|
||||||
else:
|
|
||||||
if "nclasses" not in kwargs:
|
|
||||||
raise NameError("`prototype_distribution` required if "
|
|
||||||
"both `data` and `nclasses` are not "
|
|
||||||
"provided.")
|
|
||||||
kwargs_nclasses = kwargs.pop("nclasses")
|
|
||||||
input_dim = kwargs.pop("input_dim")
|
|
||||||
if prototype_initializer in [
|
|
||||||
"stratified_mean", "stratified_random"
|
|
||||||
]:
|
|
||||||
warnings.warn(
|
|
||||||
f"`prototype_initializer`: `{prototype_initializer}` "
|
|
||||||
"requires `data`, but `data` is not provided. "
|
|
||||||
"Using randomly generated data instead.")
|
|
||||||
x_train = torch.rand(kwargs_nclasses, input_dim)
|
|
||||||
y_train = torch.arange(kwargs_nclasses)
|
|
||||||
if one_hot_labels:
|
|
||||||
y_train = torch.eye(kwargs_nclasses)[y_train]
|
|
||||||
data = [x_train, y_train]
|
|
||||||
|
|
||||||
x_train, y_train = data
|
|
||||||
x_train = torch.as_tensor(x_train).type(dtype)
|
|
||||||
y_train = torch.as_tensor(y_train).type(torch.int)
|
|
||||||
nclasses = torch.unique(y_train, dim=-1).shape[-1]
|
|
||||||
|
|
||||||
if nclasses == 1:
|
|
||||||
warnings.warn("Are you sure about having one class only?")
|
|
||||||
|
|
||||||
if x_train.ndim != 2:
|
|
||||||
raise ValueError("`data[0].ndim != 2`.")
|
|
||||||
|
|
||||||
if y_train.ndim == 2:
|
|
||||||
if y_train.shape[1] == 1 and one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `True` "
|
|
||||||
"but target labels are not one-hot-encoded.")
|
|
||||||
if y_train.shape[1] != 1 and not one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `False` "
|
|
||||||
"but target labels in `data` "
|
|
||||||
"are one-hot-encoded.")
|
|
||||||
if y_train.ndim == 1 and one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `True` "
|
|
||||||
"but target labels are not one-hot-encoded.")
|
|
||||||
|
|
||||||
# Verify input dimension if `input_dim` is provided
|
|
||||||
if "input_dim" in kwargs:
|
|
||||||
input_dim = kwargs.pop("input_dim")
|
|
||||||
if input_dim != x_train.shape[1]:
|
|
||||||
raise ValueError(f"Provided `input_dim`={input_dim} does "
|
|
||||||
"not match data dimension "
|
|
||||||
f"`data[0].shape[1]`={x_train.shape[1]}")
|
|
||||||
|
|
||||||
# Verify the number of classes if `nclasses` is provided
|
|
||||||
if "nclasses" in kwargs:
|
|
||||||
kwargs_nclasses = kwargs.pop("nclasses")
|
|
||||||
if kwargs_nclasses != nclasses:
|
|
||||||
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
|
|
||||||
"not match data labels "
|
|
||||||
"`torch.unique(data[1]).shape[0]`"
|
|
||||||
f"={nclasses}")
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
if not prototype_distribution:
|
|
||||||
prototype_distribution = [prototypes_per_class] * nclasses
|
|
||||||
with torch.no_grad():
|
|
||||||
self.prototype_distribution = torch.tensor(prototype_distribution)
|
|
||||||
|
|
||||||
self._validate_prototype_distribution()
|
|
||||||
|
|
||||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
|
||||||
prototypes, prototype_labels = self.prototype_initializer(
|
|
||||||
x_train,
|
|
||||||
y_train,
|
|
||||||
prototype_distribution=self.prototype_distribution,
|
|
||||||
one_hot=one_hot_labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register module parameters
|
|
||||||
self.prototypes = torch.nn.Parameter(prototypes)
|
|
||||||
self.prototype_labels = torch.nn.Parameter(
|
|
||||||
prototype_labels.type(dtype)).requires_grad_(False)
|
|
49
prototorch/modules/transformations.py
Normal file
49
prototorch/modules/transformations.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""ProtoTorch Transformation Layers."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from .initializers import MatrixInitializer
|
||||||
|
|
||||||
|
|
||||||
|
def _precheck_initializer(initializer):
|
||||||
|
if not isinstance(initializer, MatrixInitializer):
|
||||||
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
|
f"{MatrixInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
|
||||||
|
|
||||||
|
class Omega(torch.nn.Module):
|
||||||
|
"""The Omega mapping used in GMLVQ."""
|
||||||
|
def __init__(self,
|
||||||
|
num_replicas=1,
|
||||||
|
input_dim=None,
|
||||||
|
latent_dim=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_weights=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if initialized_weights is not None:
|
||||||
|
self._register_weights(initialized_weights)
|
||||||
|
else:
|
||||||
|
if num_replicas == 1:
|
||||||
|
shape = (input_dim, latent_dim)
|
||||||
|
else:
|
||||||
|
shape = (num_replicas, input_dim, latent_dim)
|
||||||
|
self._initialize_weights(shape, initializer)
|
||||||
|
|
||||||
|
def _register_weights(self, weights):
|
||||||
|
self.register_parameter("_omega", Parameter(weights))
|
||||||
|
|
||||||
|
def _initialize_weights(self, shape, initializer):
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
_omega = initializer.generate(shape)
|
||||||
|
self._register_weights(_omega)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._omega
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
36
prototorch/modules/wrappers.py
Normal file
36
prototorch/modules/wrappers.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""ProtoTorch Wrappers."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaLayer(torch.nn.Module):
|
||||||
|
def __init__(self, fn, name=None):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
self.name = name or fn.__name__ # lambda fns get <lambda>
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class LossLayer(torch.nn.modules.loss._Loss):
|
||||||
|
def __init__(self,
|
||||||
|
fn,
|
||||||
|
name=None,
|
||||||
|
size_average=None,
|
||||||
|
reduce=None,
|
||||||
|
reduction: str = "mean") -> None:
|
||||||
|
super().__init__(size_average=size_average,
|
||||||
|
reduce=reduce,
|
||||||
|
reduction=reduction)
|
||||||
|
self.fn = fn
|
||||||
|
self.name = name or fn.__name__ # lambda fns get <lambda>
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return self.name
|
@@ -1 +0,0 @@
|
|||||||
from .colors import color_scheme, get_legend_handles
|
|
||||||
|
46
prototorch/utils/celluloid.py
Normal file
46
prototorch/utils/celluloid.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from matplotlib.animation import ArtistAnimation
|
||||||
|
from matplotlib.artist import Artist
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
|
|
||||||
|
class Camera:
|
||||||
|
"""Make animations easier."""
|
||||||
|
def __init__(self, figure: Figure) -> None:
|
||||||
|
"""Create camera from matplotlib figure."""
|
||||||
|
self._figure = figure
|
||||||
|
# need to keep track off artists for each axis
|
||||||
|
self._offsets: Dict[str, Dict[int, int]] = {
|
||||||
|
k: defaultdict(int)
|
||||||
|
for k in
|
||||||
|
["collections", "patches", "lines", "texts", "artists", "images"]
|
||||||
|
}
|
||||||
|
self._photos: List[List[Artist]] = []
|
||||||
|
|
||||||
|
def snap(self) -> List[Artist]:
|
||||||
|
"""Capture current state of the figure."""
|
||||||
|
frame_artists: List[Artist] = []
|
||||||
|
for i, axis in enumerate(self._figure.axes):
|
||||||
|
if axis.legend_ is not None:
|
||||||
|
axis.add_artist(axis.legend_)
|
||||||
|
for name in self._offsets:
|
||||||
|
new_artists = getattr(axis, name)[self._offsets[name][i]:]
|
||||||
|
frame_artists += new_artists
|
||||||
|
self._offsets[name][i] += len(new_artists)
|
||||||
|
self._photos.append(frame_artists)
|
||||||
|
return frame_artists
|
||||||
|
|
||||||
|
def animate(self, *args, **kwargs) -> ArtistAnimation:
|
||||||
|
"""Animate the snapshots taken.
|
||||||
|
Uses matplotlib.animation.ArtistAnimation
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ArtistAnimation
|
||||||
|
"""
|
||||||
|
return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
|
@@ -1,13 +1,14 @@
|
|||||||
"""ProtoFlow color utilities."""
|
"""ProtoFlow color utilities."""
|
||||||
|
|
||||||
from matplotlib import cm
|
|
||||||
from matplotlib.colors import Normalize
|
|
||||||
from matplotlib.colors import to_hex
|
|
||||||
from matplotlib.colors import to_rgb
|
|
||||||
import matplotlib.lines as mlines
|
import matplotlib.lines as mlines
|
||||||
|
from matplotlib import cm
|
||||||
|
from matplotlib.colors import Normalize, to_hex, to_rgb
|
||||||
|
|
||||||
|
|
||||||
def color_scheme(n, cmap="viridis", form="hex", tikz=False,
|
def color_scheme(n,
|
||||||
|
cmap="viridis",
|
||||||
|
form="hex",
|
||||||
|
tikz=False,
|
||||||
zero_indexed=False):
|
zero_indexed=False):
|
||||||
"""Return *n* colors from the color scheme.
|
"""Return *n* colors from the color scheme.
|
||||||
|
|
||||||
@@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
|
|||||||
zero_indexed=zero_indexed)
|
zero_indexed=zero_indexed)
|
||||||
for label, color in zip(labels, colors.values()):
|
for label, color in zip(labels, colors.values()):
|
||||||
if marker == "dots":
|
if marker == "dots":
|
||||||
handle = mlines.Line2D([], [],
|
handle = mlines.Line2D(
|
||||||
color="white",
|
[],
|
||||||
markerfacecolor=color,
|
[],
|
||||||
marker="o",
|
color="white",
|
||||||
markersize=10,
|
markerfacecolor=color,
|
||||||
markeredgecolor="k",
|
marker="o",
|
||||||
label=label)
|
markersize=10,
|
||||||
|
markeredgecolor="k",
|
||||||
|
label=label,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
handle = mlines.Line2D([], [],
|
handle = mlines.Line2D([], [],
|
||||||
color=color,
|
color=color,
|
||||||
|
@@ -1,5 +0,0 @@
|
|||||||
matplotlib==3.1.2
|
|
||||||
pytest==5.3.4
|
|
||||||
requests==2.22.0
|
|
||||||
codecov==2.0.22
|
|
||||||
tqdm==4.44.1
|
|
26
setup.py
26
setup.py
@@ -8,8 +8,7 @@
|
|||||||
|
|
||||||
ProtoTorch Core Package
|
ProtoTorch Core Package
|
||||||
"""
|
"""
|
||||||
from setuptools import setup
|
from setuptools import find_packages, setup
|
||||||
from setuptools import find_packages
|
|
||||||
|
|
||||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||||
@@ -21,28 +20,30 @@ INSTALL_REQUIRES = [
|
|||||||
"torch>=1.3.1",
|
"torch>=1.3.1",
|
||||||
"torchvision>=0.5.0",
|
"torchvision>=0.5.0",
|
||||||
"numpy>=1.9.1",
|
"numpy>=1.9.1",
|
||||||
]
|
"sklearn",
|
||||||
DOCS = [
|
|
||||||
"recommonmark",
|
|
||||||
"sphinx",
|
|
||||||
"sphinx_rtd_theme",
|
|
||||||
"sphinxcontrib-katex",
|
|
||||||
]
|
]
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
"requests",
|
"requests",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
]
|
]
|
||||||
|
DEV = ["bumpversion"]
|
||||||
|
DOCS = [
|
||||||
|
"recommonmark",
|
||||||
|
"sphinx",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
"sphinxcontrib-katex",
|
||||||
|
"sphinx-autodoc-typehints",
|
||||||
|
]
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
"sklearn",
|
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"torchinfo",
|
"torchinfo",
|
||||||
]
|
]
|
||||||
TESTS = ["pytest"]
|
TESTS = ["codecov", "pytest"]
|
||||||
ALL = DOCS + DATASETS + EXAMPLES + TESTS
|
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.3.0-dev0",
|
version="0.5.0",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
@@ -72,6 +73,7 @@ setup(
|
|||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
|
25
tests/test_components.py
Normal file
25
tests/test_components.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""ProtoTorch components test suite."""
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def test_labcomps_zeros_init():
|
||||||
|
protos = torch.zeros(3, 2)
|
||||||
|
c = pt.components.LabeledComponents(
|
||||||
|
distribution=[1, 1, 1],
|
||||||
|
initializer=pt.components.Zeros(2),
|
||||||
|
)
|
||||||
|
assert (c.components == protos).any() == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_labcomps_warmstart():
|
||||||
|
protos = torch.randn(3, 2)
|
||||||
|
plabels = torch.tensor([1, 2, 3])
|
||||||
|
c = pt.components.LabeledComponents(
|
||||||
|
distribution=[1, 1, 1],
|
||||||
|
initializer=None,
|
||||||
|
initialized_components=[protos, plabels],
|
||||||
|
)
|
||||||
|
assert (c.components == protos).any() == True
|
||||||
|
assert (c.component_labels == plabels).any() == True
|
@@ -12,26 +12,26 @@ from prototorch.datasets import abstract, tecator
|
|||||||
class TestAbstract(unittest.TestCase):
|
class TestAbstract(unittest.TestCase):
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.Dataset('./artifacts')[0]
|
abstract.Dataset("./artifacts")[0]
|
||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
len(abstract.Dataset('./artifacts'))
|
len(abstract.Dataset("./artifacts"))
|
||||||
|
|
||||||
|
|
||||||
class TestProtoDataset(unittest.TestCase):
|
class TestProtoDataset(unittest.TestCase):
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.ProtoDataset('./artifacts')[0]
|
abstract.ProtoDataset("./artifacts")[0]
|
||||||
|
|
||||||
def test_download(self):
|
def test_download(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.ProtoDataset('./artifacts').download()
|
abstract.ProtoDataset("./artifacts").download()
|
||||||
|
|
||||||
|
|
||||||
class TestTecator(unittest.TestCase):
|
class TestTecator(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.artifacts_dir = './artifacts/Tecator'
|
self.artifacts_dir = "./artifacts/Tecator"
|
||||||
self._remove_artifacts()
|
self._remove_artifacts()
|
||||||
|
|
||||||
def _remove_artifacts(self):
|
def _remove_artifacts(self):
|
||||||
@@ -39,23 +39,23 @@ class TestTecator(unittest.TestCase):
|
|||||||
shutil.rmtree(self.artifacts_dir)
|
shutil.rmtree(self.artifacts_dir)
|
||||||
|
|
||||||
def test_download_false(self):
|
def test_download_false(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
self._remove_artifacts()
|
self._remove_artifacts()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = tecator.Tecator(rootdir, download=False)
|
_ = tecator.Tecator(rootdir, download=False)
|
||||||
|
|
||||||
def test_download_caching(self):
|
def test_download_caching(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
||||||
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
||||||
self.assertTrue('Split: Train' in train.__repr__())
|
self.assertTrue("Split: Train" in train.__repr__())
|
||||||
|
|
||||||
def test_download_train(self):
|
def test_download_train(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(root=rootdir,
|
train = tecator.Tecator(root=rootdir,
|
||||||
train=True,
|
train=True,
|
||||||
download=True,
|
download=True,
|
||||||
@@ -67,7 +67,7 @@ class TestTecator(unittest.TestCase):
|
|||||||
self.assertEqual(x_train.shape[1], 100)
|
self.assertEqual(x_train.shape[1], 100)
|
||||||
|
|
||||||
def test_download_test(self):
|
def test_download_test(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x_test, y_test = test.data, test.targets
|
x_test, y_test = test.data, test.targets
|
||||||
self.assertEqual(x_test.shape[0], 71)
|
self.assertEqual(x_test.shape[0], 71)
|
||||||
@@ -75,19 +75,19 @@ class TestTecator(unittest.TestCase):
|
|||||||
self.assertEqual(x_test.shape[1], 100)
|
self.assertEqual(x_test.shape[1], 100)
|
||||||
|
|
||||||
def test_class_to_idx(self):
|
def test_class_to_idx(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = test.class_to_idx
|
_ = test.class_to_idx
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x, y = test[0]
|
x, y = test[0]
|
||||||
self.assertEqual(x.shape[0], 100)
|
self.assertEqual(x.shape[0], 100)
|
||||||
self.assertIsInstance(y, int)
|
self.assertIsInstance(y, int)
|
||||||
|
|
||||||
def test_loadable_with_dataloader(self):
|
def test_loadable_with_dataloader(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
|
@@ -4,14 +4,13 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers, losses)
|
initializers, losses, pooling)
|
||||||
|
|
||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
class TestActivations(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
||||||
self.x = torch.randn(1024, 1)
|
self.x = torch.randn(1024, 1)
|
||||||
|
|
||||||
def test_registry(self):
|
def test_registry(self):
|
||||||
@@ -39,7 +38,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertEqual(1, f(1))
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
def test_unknown_deserialization(self):
|
||||||
for funcname in ['blubb', 'foobar']:
|
for funcname in ["blubb", "foobar"]:
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = activations.get_activation(funcname)
|
_ = activations.get_activation(funcname)
|
||||||
|
|
||||||
@@ -52,7 +51,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_sigmoid_beta1(self):
|
def test_sigmoid_beta1(self):
|
||||||
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
|
actual = activations.sigmoid_beta(self.x, beta=1.0)
|
||||||
desired = torch.sigmoid(self.x)
|
desired = torch.sigmoid(self.x)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@@ -60,7 +59,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_swish_beta1(self):
|
def test_swish_beta1(self):
|
||||||
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
|
actual = activations.swish_beta(self.x, beta=1.0)
|
||||||
desired = self.x * torch.sigmoid(self.x)
|
desired = self.x * torch.sigmoid(self.x)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@@ -76,7 +75,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_wtac(self):
|
def test_wtac(self):
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([2, 0])
|
desired = torch.tensor([2, 0])
|
||||||
@@ -86,7 +85,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_wtac_unequal_dist(self):
|
def test_wtac_unequal_dist(self):
|
||||||
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
|
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||||
labels = torch.tensor([0, 1, 1])
|
labels = torch.tensor([0, 1, 1])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([0, 1])
|
desired = torch.tensor([0, 1])
|
||||||
@@ -96,7 +95,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_wtac_one_hot(self):
|
def test_wtac_one_hot(self):
|
||||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||||
labels = torch.tensor([[0, 1], [1, 0]])
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([[0, 1], [1, 0]])
|
desired = torch.tensor([[0, 1], [1, 0]])
|
||||||
@@ -105,42 +104,102 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_knnc_k1(self):
|
||||||
|
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.knnc(d, labels, k=1)
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestPooling(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_stratified_min(self):
|
def test_stratified_min(self):
|
||||||
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_min_one_hot(self):
|
def test_stratified_min_one_hot(self):
|
||||||
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
labels = torch.eye(3)[labels]
|
labels = torch.eye(3)[labels]
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_min_simple(self):
|
def test_stratified_min_trivial(self):
|
||||||
d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
|
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 1, 2])
|
labels = torch.tensor([0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = pooling.stratified_min_pooling(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
def test_stratified_max(self):
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||||
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
actual = pooling.stratified_max_pooling(d, labels)
|
||||||
desired = torch.tensor([2, 0])
|
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_max_one_hot(self):
|
||||||
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
|
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||||
|
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||||
|
actual = pooling.stratified_max_pooling(d, labels)
|
||||||
|
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_sum(self):
|
||||||
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
|
labels = torch.LongTensor([0, 0, 1, 2])
|
||||||
|
actual = pooling.stratified_sum_pooling(d, labels)
|
||||||
|
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_sum_one_hot(self):
|
||||||
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
|
labels = torch.eye(3)[labels]
|
||||||
|
actual = pooling.stratified_sum_pooling(d, labels)
|
||||||
|
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_prod(self):
|
||||||
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||||
|
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||||
|
actual = pooling.stratified_prod_pooling(d, labels)
|
||||||
|
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -194,12 +253,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@@ -254,14 +313,14 @@ class TestDistances(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_lpnorm_pinf(self):
|
def test_lpnorm_pinf(self):
|
||||||
actual = distances.lpnorm_distance(self.x, self.y, p=float('inf'))
|
actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
|
||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=float('inf'),
|
p=float("inf"),
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)
|
)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
@@ -275,12 +334,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@@ -293,12 +352,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@@ -311,8 +370,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
class TestInitializers(unittest.TestCase):
|
class TestInitializers(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = [
|
self.flist = [
|
||||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
"zeros",
|
||||||
'stratified_random'
|
"ones",
|
||||||
|
"rand",
|
||||||
|
"randn",
|
||||||
|
"stratified_mean",
|
||||||
|
"stratified_random",
|
||||||
]
|
]
|
||||||
self.x = torch.tensor(
|
self.x = torch.tensor(
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||||
@@ -340,7 +403,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
self.assertEqual(1, f(1))
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
def test_unknown_deserialization(self):
|
||||||
for funcname in ['blubb', 'foobar']:
|
for funcname in ["blubb", "foobar"]:
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = initializers.get_initializer(funcname)
|
_ = initializers.get_initializer(funcname)
|
||||||
|
|
||||||
@@ -383,7 +446,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal1(self):
|
def test_stratified_mean_equal1(self):
|
||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -393,7 +456,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -402,8 +465,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal2(self):
|
def test_stratified_mean_equal2(self):
|
||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -413,8 +476,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -423,8 +486,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal(self):
|
def test_stratified_mean_unequal(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -434,8 +497,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@@ -444,8 +507,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal_one_hot(self):
|
def test_stratified_mean_unequal_one_hot(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
@@ -460,8 +523,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
||||||
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
desired1,
|
desired1,
|
||||||
|
@@ -1,279 +0,0 @@
|
|||||||
"""ProtoTorch modules test suite."""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from prototorch.modules import losses, prototypes
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrototypes(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.x = torch.tensor(
|
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
|
||||||
dtype=torch.float32)
|
|
||||||
self.y = torch.tensor([0, 0, 1, 1])
|
|
||||||
self.gen = torch.manual_seed(42)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_input_dim(self):
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = prototypes.Prototypes1D(nclasses=2)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_nclasses(self):
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_nclasses_1(self):
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_pdist(self):
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=6,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=4,
|
|
||||||
prototype_initializer='ones')
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.ones(8, 6)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_data(self):
|
|
||||||
pdist = [2, 2]
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
|
||||||
prototype_distribution=pdist,
|
|
||||||
prototype_initializer='zeros')
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_proto_init_without_data(self):
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=3,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=None)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_torch_pdist(self):
|
|
||||||
pdist = torch.tensor([2, 2])
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
|
||||||
prototype_distribution=pdist,
|
|
||||||
prototype_initializer='zeros')
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1], [0]], [1, 0]])
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_without_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=None,
|
|
||||||
one_hot_labels=True)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_false(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
|
||||||
but the provided `data` has one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=([[0.], [1.]], [[0, 1], [1, 0]]),
|
|
||||||
one_hot_labels=False)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
|
||||||
but the provided `data` does not contain one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=([[0.], [1.]], [0, 1]),
|
|
||||||
one_hot_labels=True)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_true(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
|
||||||
but the provided `data` contains 2D targets but
|
|
||||||
does not contain one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=([[0.], [1.]], [[0], [1]]),
|
|
||||||
one_hot_labels=True)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_dtype(self):
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1], [0]], [1, 0]],
|
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
def test_prototypes1d_inputndim_with_data(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
|
||||||
nclasses=1,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
data=[[1.], [1]])
|
|
||||||
|
|
||||||
def test_prototypes1d_inputdim_with_data(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=2,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
|
||||||
|
|
||||||
def test_prototypes1d_nclasses_with_data(self):
|
|
||||||
"""Test ValueError raise if provided `nclasses` is not the same
|
|
||||||
as the one computed from the provided `data`.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
nclasses=1,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1.], [2.]], [1, 2]])
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_ppc(self):
|
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
|
||||||
prototypes_per_class=2,
|
|
||||||
prototype_initializer='zeros')
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_pdist(self):
|
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
|
||||||
prototype_distribution=[6, 9],
|
|
||||||
prototype_initializer='zeros')
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(15, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_func_initializer(self):
|
|
||||||
def my_initializer(*args, **kwargs):
|
|
||||||
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
|
||||||
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=99,
|
|
||||||
nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer=my_initializer)
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = 99 * torch.ones(2, 99)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_forward(self):
|
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
|
|
||||||
protos, _ = p1()
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.ones(2, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_dist_validate(self):
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = p1._validate_prototype_distribution()
|
|
||||||
|
|
||||||
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
|
||||||
rep = p1.extra_repr()
|
|
||||||
self.assertNotEqual(rep, '')
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
del self.x, self.y, self.gen
|
|
||||||
_ = torch.seed()
|
|
||||||
|
|
||||||
|
|
||||||
class TestLosses(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
|
||||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
|
||||||
|
|
||||||
def test_glvqloss_forward_1ppc(self):
|
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
|
||||||
squashing='sigmoid_beta',
|
|
||||||
beta=100)
|
|
||||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
|
||||||
labels = torch.tensor([0, 1])
|
|
||||||
targets = torch.ones(100)
|
|
||||||
outputs = [d, labels]
|
|
||||||
loss = criterion(outputs, targets)
|
|
||||||
loss_value = loss.item()
|
|
||||||
self.assertAlmostEqual(loss_value, 0.0)
|
|
||||||
|
|
||||||
def test_glvqloss_forward_2ppc(self):
|
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
|
||||||
squashing='sigmoid_beta',
|
|
||||||
beta=100)
|
|
||||||
d = torch.stack([
|
|
||||||
torch.ones(100),
|
|
||||||
torch.ones(100),
|
|
||||||
torch.zeros(100),
|
|
||||||
torch.ones(100)
|
|
||||||
],
|
|
||||||
dim=1)
|
|
||||||
labels = torch.tensor([0, 0, 1, 1])
|
|
||||||
targets = torch.ones(100)
|
|
||||||
outputs = [d, labels]
|
|
||||||
loss = criterion(outputs, targets)
|
|
||||||
loss_value = loss.item()
|
|
||||||
self.assertAlmostEqual(loss_value, 0.0)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
pass
|
|
15
tox.ini
15
tox.ini
@@ -1,15 +0,0 @@
|
|||||||
# tox (https://tox.readthedocs.io/) is a tool for running tests
|
|
||||||
# in multiple virtualenvs. This configuration file will run the
|
|
||||||
# test suite on all supported python versions. To use it, "pip install tox"
|
|
||||||
# and then run "tox" from this directory.
|
|
||||||
|
|
||||||
[tox]
|
|
||||||
envlist = py36,py37,py38
|
|
||||||
|
|
||||||
[testenv]
|
|
||||||
deps =
|
|
||||||
pytest
|
|
||||||
coverage
|
|
||||||
commands =
|
|
||||||
pip install -e .
|
|
||||||
coverage run -m pytest
|
|
Reference in New Issue
Block a user