Compare commits
208 Commits
v0.1.0-dev
...
v0.5.1
Author | SHA1 | Date | |
---|---|---|---|
|
6ffd14e85c | ||
|
40c1021c20 | ||
|
acf3272fd7 | ||
|
c73f8e7a28 | ||
|
bf23d5f7f8 | ||
|
bcde3f6ac8 | ||
|
d5229b1750 | ||
|
fc4b143fbb | ||
|
11cfa79746 | ||
|
d0ae94f2af | ||
|
2c908a8361 | ||
|
e4257ec1f1 | ||
|
aaad2b8626 | ||
|
c0c0044a42 | ||
|
47d7f5831f | ||
|
4f1c879528 | ||
|
2272c55092 | ||
|
b03c9b1d3c | ||
|
0c28eda706 | ||
|
7bc0bfa3ab | ||
|
827958a28a | ||
|
8200e1d3d8 | ||
|
729b20e9ab | ||
|
ca8ac7a43b | ||
|
b724a28a6f | ||
|
1e0a8392a2 | ||
|
2eb7b05653 | ||
|
d8a0b2dfcc | ||
|
2a7394b593 | ||
|
b1e64c8b8b | ||
|
70cf17607e | ||
|
b1568a550a | ||
|
e8e803e8ef | ||
|
2c453265fe | ||
|
7336d35fee | ||
|
bc18952c05 | ||
|
8e8d0b9c2c | ||
|
5a7da2b40b | ||
|
b6d38f442b | ||
|
8e8851d962 | ||
|
27b43b06a7 | ||
|
ff69eb1256 | ||
|
4ca581909a | ||
|
2722d976f5 | ||
|
946cda00d2 | ||
|
8227525c82 | ||
|
e61ae73749 | ||
|
040d1ee9e8 | ||
|
7f0da894fa | ||
|
62726df278 | ||
|
0ba09db6fe | ||
|
87334c11e6 | ||
|
40ef3aeda2 | ||
|
94fe4435a8 | ||
|
c204bc8e1f | ||
|
00615ae837 | ||
|
9f5f0d12dd | ||
|
8a291f7bfb | ||
|
21e3e3b82d | ||
|
a6bd6e130a | ||
|
fcdfa52892 | ||
|
73e6fe384e | ||
|
aff7a385a3 | ||
|
1e23ba05fa | ||
|
ee30d4da5b | ||
|
14508f0600 | ||
|
e3f8828da4 | ||
|
30adbf705c | ||
|
ee42fd68b1 | ||
|
736d9a6349 | ||
|
0055e15bc1 | ||
|
b2e1df7308 | ||
|
b935e9caf3 | ||
|
503ef0e05f | ||
|
dc6248413c | ||
|
e73b70ceb7 | ||
|
639198e774 | ||
|
768d969f89 | ||
|
aec422c277 | ||
|
6c14170de6 | ||
|
36a330aa66 | ||
|
acd4ac6a86 | ||
|
abe64cfe8f | ||
|
caae95d01d | ||
|
088429a16a | ||
|
b6145223c8 | ||
|
09256956f3 | ||
|
0ca90fdcee | ||
|
be21412f8a | ||
|
ae6bc47f87 | ||
|
7bb93f027a | ||
|
bc20acd63b | ||
|
a864cf5d4d | ||
|
2175f524e8 | ||
|
c1c21e92df | ||
|
2b676ee06e | ||
|
dda2f1d779 | ||
|
3a8388e24f | ||
|
a9eef8ae6d | ||
|
ac3091d8da | ||
|
ce3991de94 | ||
|
47b4b9bcb1 | ||
|
19475d7e2b | ||
|
269eb8ba25 | ||
|
b06ded683d | ||
|
466e9bde6b | ||
|
fc7d64aaea | ||
|
9a7d3192c0 | ||
|
e686adbea1 | ||
|
b7d53aa5f1 | ||
|
9b663477fd | ||
|
a70166280a | ||
|
a083c4b276 | ||
|
40751aa50a | ||
|
7c30ffe2c7 | ||
|
e1d56595c1 | ||
|
4540c8848e | ||
|
c88f288d12 | ||
|
e2918dffed | ||
|
7d9dfc27ee | ||
|
ae75b9ebf7 | ||
|
34973808b8 | ||
|
c42df6e203 | ||
|
101b50f4e6 | ||
|
db842b79bb | ||
|
98a8fc52fa | ||
|
6796ec494f | ||
|
cd9303267b | ||
|
599dfc3fda | ||
|
5b2ab34232 | ||
|
429570323e | ||
|
3edb13baf4 | ||
|
42cedbb2b8 | ||
|
2322876eb6 | ||
|
bc7df1059f | ||
|
4c7c9cc34a | ||
|
e39f307194 | ||
|
e2867f696e | ||
|
30dc0ea8b1 | ||
|
895281aabd | ||
|
a55320a65b | ||
|
559f4acc73 | ||
|
9b5bccc39d | ||
|
a8a99f6971 | ||
|
58efa5a4cf | ||
|
9672aab8e2 | ||
|
d5ab9c3771 | ||
|
3e6aa6a20b | ||
|
b138277608 | ||
|
9ccbec52f7 | ||
|
cd652508b9 | ||
|
fa72c7156e | ||
|
6e72b9267a | ||
|
8a4a596035 | ||
|
0cfbc0473b | ||
|
cf0659d881 | ||
|
d17b9a3346 | ||
|
532f63b1de | ||
|
c11a3860df | ||
|
dab91e471a | ||
|
a167565857 | ||
|
e063625486 | ||
|
89eb5358a0 | ||
|
5c59515128 | ||
|
7eb7a6b194 | ||
|
5811c4b9f9 | ||
|
7b1887d56e | ||
|
63a25e7a38 | ||
|
a0f20a40f6 | ||
|
88cbe0a126 | ||
|
a3548e0ddd | ||
|
3cfbc49254 | ||
|
2b82830590 | ||
|
553b1e1a65 | ||
|
a9d2855323 | ||
|
cf7d7b5d9d | ||
|
a22c752342 | ||
|
4158586cb9 | ||
|
f80d9648c3 | ||
|
e54bf07030 | ||
|
8c629c0cb1 | ||
|
8f3a43f62a | ||
|
955661af95 | ||
|
c54d14c55e | ||
|
6090aad176 | ||
|
1ec7bd261b | ||
|
da3b0cc262 | ||
|
f640a22cf2 | ||
|
c843ace63d | ||
|
242c9de3b6 | ||
|
438a5b9360 | ||
|
f98f3d095e | ||
|
21b0279839 | ||
|
b19cbcb76a | ||
|
7d5ab81dbf | ||
|
bde408a80e | ||
|
900955d67a | ||
|
3757c937b3 | ||
|
38f637aaeb | ||
|
6ddfe48a95 | ||
|
bf0e694321 | ||
|
e2c9848120 | ||
|
dc60b7e5b5 | ||
|
c21913fdd4 | ||
|
59e31f94ab | ||
|
cddefa9b0d | ||
|
26d71fdd60 | ||
|
ced8f532dd |
@@ -1,21 +1,13 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.0-dev0
|
||||
current_version = 0.5.1
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
serialize =
|
||||
{major}.{minor}.{patch}-{release}{build}
|
||||
{major}.{minor}.{patch}
|
||||
|
||||
[bumpversion:part:release]
|
||||
optional_value = prod
|
||||
first_value = dev
|
||||
values =
|
||||
dev
|
||||
rc
|
||||
prod
|
||||
|
||||
[bumpversion:file:setup.py]
|
||||
|
||||
[bumpversion:file:./prototorch/__init__.py]
|
||||
|
||||
[bumpversion:file:./docs/source/conf.py]
|
||||
|
15
.codacy.yml
Normal file
15
.codacy.yml
Normal file
@@ -0,0 +1,15 @@
|
||||
# To validate the contents of your configuration file
|
||||
# run the following command in the folder where the configuration file is located:
|
||||
# codacy-analysis-cli validate-configuration --directory `pwd`
|
||||
# To analyse, run:
|
||||
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
||||
---
|
||||
engines:
|
||||
pylintpython3:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
remark-lint:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
exclude_paths:
|
||||
- 'tests/**'
|
31
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
31
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Install Prototorch by running '...'
|
||||
2. Run script '...'
|
||||
3. See errors
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. Ubuntu 20.10]
|
||||
- Prototorch Version: [e.g. v0.4.0]
|
||||
- Python Version: [e.g. 3.9.5]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
4
.github/workflows/pythonapp.yml
vendored
4
.github/workflows/pythonapp.yml
vendored
@@ -1,7 +1,7 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Tests
|
||||
name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .
|
||||
pip install .[all]
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
pip install flake8
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -154,4 +154,5 @@ scratch*
|
||||
# End of https://www.gitignore.io/api/visualstudiocode
|
||||
.vscode/
|
||||
|
||||
reports
|
||||
reports
|
||||
artifacts
|
||||
|
54
.pre-commit-config.yaml
Normal file
54
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,54 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-ast
|
||||
- id: check-case-conflict
|
||||
|
||||
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
hooks:
|
||||
- id: autoflake
|
||||
|
||||
- repo: http://github.com/PyCQA/isort
|
||||
rev: 5.8.0
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: 'v0.902'
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: prototorch
|
||||
additional_dependencies: [types-pkg_resources]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
||||
hooks:
|
||||
- id: yapf
|
||||
|
||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||
rev: v1.9.0 # Use the ref you want to point at
|
||||
hooks:
|
||||
- id: python-use-type-annotations
|
||||
- id: python-no-log-warn
|
||||
- id: python-check-blanket-noqa
|
||||
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.19.4
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/jorisroovers/gitlint
|
||||
rev: "v0.15.1"
|
||||
hooks:
|
||||
- id: gitlint
|
||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
27
.readthedocs.yml
Normal file
27
.readthedocs.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
# .readthedocs.yml
|
||||
# Read the Docs configuration file
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/source/conf.py
|
||||
fail_on_warning: true
|
||||
|
||||
# Build documentation with MkDocs
|
||||
# mkdocs:
|
||||
# configuration: mkdocs.yml
|
||||
|
||||
# Optionally build your docs in additional formats such as PDF and ePub
|
||||
formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.8
|
||||
install:
|
||||
- method: pip
|
||||
path: .
|
||||
extra_requirements:
|
||||
- all
|
36
.travis.yml
Normal file
36
.travis.yml
Normal file
@@ -0,0 +1,36 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python: 3.8
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
- "./tests/artifacts"
|
||||
- "$HOME/datasets"
|
||||
install:
|
||||
- pip install .[all] --progress-bar off
|
||||
|
||||
# Generate code coverage report
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
|
||||
# Push the results to codecov
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
password:
|
||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
@@ -1,9 +1,13 @@
|
||||
include .bumpversion.cfg
|
||||
include LICENSE
|
||||
include tox.ini
|
||||
include *.md
|
||||
include *.txt
|
||||
include *.yml
|
||||
recursive-include docs *.bat
|
||||
recursive-include docs *.png
|
||||
recursive-include docs *.py
|
||||
recursive-include docs *.rst
|
||||
recursive-include docs Makefile
|
||||
recursive-include examples *.py
|
||||
recursive-include tests *.py
|
||||
|
80
README.md
80
README.md
@@ -1,49 +1,73 @@
|
||||
# ProtoTorch
|
||||
# ProtoTorch: Prototype Learning in PyTorch
|
||||
|
||||
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
|
||||
prototype-based machine learning algorithms.
|
||||

|
||||
|
||||

|
||||
[](https://travis-ci.org/si-cim/prototorch)
|
||||

|
||||
[](https://github.com/si-cim/prototorch/releases)
|
||||
[](https://pypi.org/project/prototorch/)
|
||||
[](https://codecov.io/gh/si-cim/prototorch)
|
||||
[](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
|
||||

|
||||
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
||||
|
||||
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
||||
|
||||
## Description
|
||||
|
||||
This is a Python toolbox brewed at the Mittweida University of Applied Sciences
|
||||
in Germany for bleeding-edge research in Learning Vector Quantization (LVQ)
|
||||
and potentially other prototype-based methods. Although there are
|
||||
other (perhaps more extensive) LVQ toolboxes available out there, the focus of
|
||||
ProtoTorch is ease-of-use, extensibility and speed.
|
||||
|
||||
Many popular prototype-based Machine Learning (ML) algorithms like K-Nearest
|
||||
Neighbors (KNN), Generalized Learning Vector Quantization (GLVQ) and Generalized
|
||||
Matrix Learning Vector Quantization (GMLVQ) are implemented using the "nn" API
|
||||
provided by PyTorch.
|
||||
in Germany for bleeding-edge research in Prototype-based Machine Learning
|
||||
methods and other interpretable models. The focus of ProtoTorch is ease-of-use,
|
||||
extensibility and speed.
|
||||
|
||||
## Installation
|
||||
|
||||
ProtoTorch can be installed using `pip`.
|
||||
```bash
|
||||
pip install -U prototorch
|
||||
```
|
||||
pip install prototorch
|
||||
To also install the extras, use
|
||||
```bash
|
||||
pip install -U prototorch[all]
|
||||
```
|
||||
|
||||
*Note: If you're using [ZSH](https://www.zsh.org/) (which is also the default
|
||||
shell on MacOS now), the square brackets `[ ]` have to be escaped like so:
|
||||
`\[\]`, making the install command `pip install -U prototorch\[all\]`.*
|
||||
|
||||
To install the bleeding-edge features and improvements:
|
||||
```
|
||||
```bash
|
||||
git clone https://github.com/si-cim/prototorch.git
|
||||
git checkout dev
|
||||
cd prototorch
|
||||
pip install -e .
|
||||
git checkout dev
|
||||
pip install -e .[all]
|
||||
```
|
||||
|
||||
## Usage
|
||||
## Documentation
|
||||
|
||||
ProtoTorch is modular. It is very easy to use the modular pieces provided by
|
||||
ProtoTorch, like the layers, losses, callbacks and metrics to build your own
|
||||
prototype-based (instance-based) models. These pieces blend in seamlessly with
|
||||
numpy and PyTorch to allow you to mix and match the modules from ProtoTorch with
|
||||
other PyTorch modules.
|
||||
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
|
||||
that link not work, try <https://prototorch.readthedocs.io/en/latest/>.
|
||||
|
||||
ProtoTorch comes prepackaged with many popular LVQ algorithms in a convenient
|
||||
API, with more algorithms and techniques coming soon. If you would simply like
|
||||
to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
|
||||
lets you do this without requiring a black-belt in high-performance Tensor
|
||||
computation.
|
||||
## Contribution
|
||||
|
||||
This repository contains definitions for [git hooks](https://githooks.com).
|
||||
[Pre-commit](https://pre-commit.com) gets installed as a development dependency with prototorch.
|
||||
Please install the hooks by running
|
||||
```bash
|
||||
pre-commit install
|
||||
pre-commit install --hook-type commit-msg
|
||||
```
|
||||
before creating the first commit.
|
||||
|
||||
## Bibtex
|
||||
|
||||
If you would like to cite the package, please use this:
|
||||
```bibtex
|
||||
@misc{Ravichandran2020b,
|
||||
author = {Ravichandran, J},
|
||||
title = {ProtoTorch},
|
||||
year = {2020},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/si-cim/prototorch}}
|
||||
}
|
||||
|
19
RELEASE.md
Normal file
19
RELEASE.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# ProtoTorch Releases
|
||||
|
||||
## Release 0.5.0
|
||||
|
||||
- Breaking: Removed deprecated `prototorch.modules.Prototypes1D`.
|
||||
- Use `prototorch.components.LabeledComponents` instead.
|
||||
|
||||
## Release 0.2.0
|
||||
|
||||
- Fixes in example scripts.
|
||||
|
||||
## Release 0.1.1-dev0
|
||||
|
||||
- Minor bugfixes.
|
||||
- 100% line coverage.
|
||||
|
||||
## Release 0.1.0-dev0
|
||||
|
||||
Initial public release of ProtoTorch.
|
20
docs/Makefile
Normal file
20
docs/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= python3 -m sphinx
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
35
docs/make.bat
Normal file
35
docs/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.http://sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
4
docs/requirements.txt
Normal file
4
docs/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch==1.6.0
|
||||
matplotlib==3.1.2
|
||||
sphinx_rtd_theme==0.5.0
|
||||
sphinxcontrib-katex==0.6.1
|
BIN
docs/source/_static/img/horizontal-lockup.png
Normal file
BIN
docs/source/_static/img/horizontal-lockup.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 88 KiB |
57
docs/source/api.rst
Normal file
57
docs/source/api.rst
Normal file
@@ -0,0 +1,57 @@
|
||||
.. ProtoTorch API Reference
|
||||
|
||||
ProtoTorch API Reference
|
||||
======================================
|
||||
|
||||
Datasets
|
||||
--------------------------------------
|
||||
|
||||
Common Datasets
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: prototorch.datasets
|
||||
:members:
|
||||
|
||||
|
||||
Abstract Datasets
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Abstract Datasets are used to build your own datasets.
|
||||
|
||||
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
|
||||
:members:
|
||||
|
||||
Functions
|
||||
--------------------------------------
|
||||
|
||||
**Dimensions:**
|
||||
|
||||
- :math:`B` ... Batch size
|
||||
- :math:`P` ... Number of prototypes
|
||||
- :math:`n_x` ... Data dimension for vectorial data
|
||||
- :math:`n_w` ... Data dimension for vectorial prototypes
|
||||
|
||||
Activations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: prototorch.functions.activations
|
||||
:members:
|
||||
:exclude-members: register_activation, get_activation
|
||||
:undoc-members:
|
||||
|
||||
Distances
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: prototorch.functions.distances
|
||||
:members:
|
||||
:exclude-members: sed
|
||||
:undoc-members:
|
||||
|
||||
Modules
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.modules
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Utilities
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.utils
|
||||
:members:
|
||||
:undoc-members:
|
192
docs/source/conf.py
Normal file
192
docs/source/conf.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file only contains a selection of the most common options. For a full
|
||||
# list see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../"))
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "ProtoTorch"
|
||||
copyright = "2021, Jensun Ravichandran"
|
||||
author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.5.1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
#
|
||||
needs_sphinx = "1.6"
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"recommonmark",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.todo",
|
||||
"sphinx.ext.coverage",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib.katex",
|
||||
'sphinx_autodoc_typehints',
|
||||
]
|
||||
|
||||
# katex_prerender = True
|
||||
katex_prerender = False
|
||||
|
||||
napoleon_use_ivar = True
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffixes as a list of strings:
|
||||
#
|
||||
source_suffix = [".rst", ".md"]
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = "index"
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = []
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use. Choose from:
|
||||
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
|
||||
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
|
||||
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
|
||||
# "lovelace", "algol", "algol_nu", "arduino", "rainbow_dash", "abap",
|
||||
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
|
||||
# "stata-dark", "inkpot"]
|
||||
pygments_style = "monokai"
|
||||
|
||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||
todo_include_todos = True
|
||||
|
||||
# Disable docstring inheritance
|
||||
autodoc_inherit_docstrings = False
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
# https://sphinx-themes.org/
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
html_logo = "_static/img/horizontal-lockup.png"
|
||||
|
||||
html_theme_options = {
|
||||
"logo_only": True,
|
||||
"display_version": True,
|
||||
"prev_next_buttons_location": "bottom",
|
||||
"style_external_links": False,
|
||||
"style_nav_header_background": "#ffffff",
|
||||
# Toc options
|
||||
"collapse_navigation": True,
|
||||
"sticky_navigation": True,
|
||||
"navigation_depth": 4,
|
||||
"includehidden": True,
|
||||
"titles_only": False,
|
||||
}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_css_files = [
|
||||
"https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
|
||||
]
|
||||
|
||||
# -- Options for HTMLHelp output ------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = "protoflowdoc"
|
||||
|
||||
# -- Options for LaTeX output ---------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
# The paper size ("letterpaper" or "a4paper").
|
||||
#
|
||||
# "papersize": "letterpaper",
|
||||
# The font size ("10pt", "11pt" or "12pt").
|
||||
#
|
||||
# "pointsize": "10pt",
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# "preamble": "",
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# "figure_align": "htbp",
|
||||
}
|
||||
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(
|
||||
master_doc,
|
||||
"prototorch.tex",
|
||||
"ProtoTorch Documentation",
|
||||
"Jensun Ravichandran",
|
||||
"manual",
|
||||
),
|
||||
]
|
||||
|
||||
# -- Options for manual page output ---------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
|
||||
1)]
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
|
||||
# Grouping the document tree into Texinfo files. List of tuples
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(
|
||||
master_doc,
|
||||
"prototorch",
|
||||
"ProtoTorch Documentation",
|
||||
author,
|
||||
"prototorch",
|
||||
"Prototype-based machine learning in PyTorch.",
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/", None),
|
||||
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
||||
"torch": ('https://pytorch.org/docs/stable/', None),
|
||||
"pytorch_lightning":
|
||||
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
|
||||
}
|
||||
|
||||
# -- Options for Epub output ----------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
|
||||
|
||||
epub_cover = ()
|
||||
version = release
|
22
docs/source/index.rst
Normal file
22
docs/source/index.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. ProtoTorch documentation master file
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
About ProtoTorch
|
||||
================
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 3
|
||||
:caption: Contents:
|
||||
|
||||
self
|
||||
api
|
||||
|
||||
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
|
||||
research in prototype-based machine learning algorithms.
|
||||
|
||||
Indices
|
||||
=======
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
@@ -1,103 +0,0 @@
|
||||
"""ProtoTorch GLVQ example using 2D Iris data"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import AddPrototypes1D
|
||||
|
||||
# Prepare and preprocess the data
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
|
||||
|
||||
# Define the GLVQ model
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.p1 = AddPrototypes1D(input_dim=2,
|
||||
prototypes_per_class=1,
|
||||
nclasses=3,
|
||||
prototype_initializer='zeros')
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
dis = euclidean_distance(x, protos)
|
||||
return dis, plabels
|
||||
|
||||
|
||||
# Build the GLVQ model
|
||||
model = Model()
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
|
||||
# Training loop
|
||||
fig = plt.figure('Prototype Visualization')
|
||||
for epoch in range(70):
|
||||
# Compute loss.
|
||||
distances, plabels = model(torch.tensor(x_train))
|
||||
loss = criterion([distances, plabels], torch.tensor(y_train))
|
||||
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}')
|
||||
|
||||
# Take a gradient descent step
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes from the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
ax.cla()
|
||||
cmap = 'viridis'
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
||||
ax.scatter(protos[:, 0],
|
||||
protos[:, 1],
|
||||
c=plabels,
|
||||
cmap=cmap,
|
||||
edgecolor='k',
|
||||
marker='D',
|
||||
s=50)
|
||||
|
||||
# Paint decision regions
|
||||
border = 1
|
||||
resolution = 50
|
||||
x = np.vstack((x_train, protos))
|
||||
x_min, x_max = x[:, 0].min(), x[:, 0].max()
|
||||
y_min, y_max = x[:, 1].min(), x[:, 1].max()
|
||||
x_min, x_max = x_min - border, x_max + border
|
||||
y_min, y_max = y_min - border, y_max + border
|
||||
try:
|
||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
|
||||
np.arange(y_min, y_max, 1.0 / resolution))
|
||||
except ValueError as ve:
|
||||
print(ve)
|
||||
raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
|
||||
f'x_min - x_max is {x_max - x_min}.')
|
||||
except MemoryError as me:
|
||||
print(me)
|
||||
raise ValueError('Too many points. ' 'Try reducing the resolution.')
|
||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||
|
||||
torch_input = torch.from_numpy(mesh_input)
|
||||
d = model(torch_input)[0]
|
||||
y_pred = np.argmin(d.detach().numpy(), axis=1)
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
|
||||
# Plot voronoi regions
|
||||
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
|
||||
|
||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||
plt.pause(0.1)
|
65
examples/new_components.py
Normal file
65
examples/new_components.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""This example script shows the usage of the new components architecture.
|
||||
|
||||
Serialization/deserialization also works as expected.
|
||||
"""
|
||||
|
||||
# DATASET
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
|
||||
x_train = torch.Tensor(x_train)
|
||||
y_train = torch.Tensor(y_train)
|
||||
num_classes = len(torch.unique(y_train))
|
||||
|
||||
# CREATE NEW COMPONENTS
|
||||
from prototorch.components import *
|
||||
from prototorch.components.initializers import *
|
||||
|
||||
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||
print(unsupervised())
|
||||
|
||||
prototypes = LabeledComponents(
|
||||
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(prototypes())
|
||||
|
||||
components = ReasoningComponents(
|
||||
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(components())
|
||||
|
||||
# TEST SERIALIZATION
|
||||
import io
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(unsupervised, save)
|
||||
save.seek(0)
|
||||
serialized_unsupervised = torch.load(save)
|
||||
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||
), "Serialization of Components failed."
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(prototypes, save)
|
||||
save.seek(0)
|
||||
serialized_prototypes = torch.load(save)
|
||||
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||
component_labels), "Serialization of Components failed."
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(components, save)
|
||||
save.seek(0)
|
||||
serialized_components = torch.load(save)
|
||||
|
||||
assert torch.all(components.components == serialized_components.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||
), "Serialization of Components failed."
|
@@ -1 +1,46 @@
|
||||
__version__ = '0.1.0-dev0'
|
||||
"""ProtoTorch package."""
|
||||
|
||||
import pkgutil
|
||||
from typing import List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from . import components, datasets, functions, modules, utils
|
||||
from .datasets import *
|
||||
|
||||
# Core Setup
|
||||
__version__ = "0.5.1"
|
||||
|
||||
__all_core__ = [
|
||||
"datasets",
|
||||
"functions",
|
||||
"modules",
|
||||
"components",
|
||||
"utils",
|
||||
]
|
||||
|
||||
# Plugin Loader
|
||||
__path__: List[str] = pkgutil.extend_path(__path__, __name__)
|
||||
|
||||
|
||||
def discover_plugins():
    """Load every entry point registered in the ``prototorch.plugins`` group.

    Returns:
        dict: maps each entry point's name to the object returned by its
        ``load()`` call.
    """
    plugins = {}
    for entry_point in pkg_resources.iter_entry_points("prototorch.plugins"):
        plugins[entry_point.name] = entry_point.load()
    return plugins
|
||||
|
||||
|
||||
discovered_plugins = discover_plugins()
|
||||
locals().update(discovered_plugins)
|
||||
|
||||
# Generate combined __version__ and __all__
|
||||
version_plugins = "\n".join([
|
||||
"- " + name + ": v" + plugin.__version__
|
||||
for name, plugin in discovered_plugins.items()
|
||||
])
|
||||
if version_plugins != "":
|
||||
version_plugins = "\nPlugins: \n" + version_plugins
|
||||
|
||||
version = "core: v" + __version__ + version_plugins
|
||||
__all__ = __all_core__ + list(discovered_plugins.keys())
|
||||
|
2
prototorch/components/__init__.py
Normal file
2
prototorch/components/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from prototorch.components.components import *
|
||||
from prototorch.components.initializers import *
|
270
prototorch/components/components.py
Normal file
270
prototorch/components/components.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""ProtoTorch components modules."""
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||
ComponentsInitializer,
|
||||
EqualLabelsInitializer,
|
||||
UnequalLabelsInitializer,
|
||||
ZeroReasoningsInitializer)
|
||||
|
||||
from .initializers import parse_data_arg
|
||||
|
||||
|
||||
def get_labels_object(distribution):
    """Build a labels initializer from a `distribution` description.

    Accepted forms, as dispatched below:

    - dict containing the key ``"num_classes"``: the values under
      ``"num_classes"`` and ``"prototypes_per_class"`` are passed to
      `EqualLabelsInitializer`.
    - any other dict: its values and keys are passed (in that order) to
      `UnequalLabelsInitializer`.
    - tuple ``(num_classes, prototypes_per_class)``: unpacked into an
      `EqualLabelsInitializer`.
    - list: passed unchanged to `UnequalLabelsInitializer`.

    Raises:
        ValueError: if `distribution` is none of the above.
    """
    if isinstance(distribution, dict):
        if "num_classes" in distribution:
            return EqualLabelsInitializer(
                distribution["num_classes"],
                distribution["prototypes_per_class"])
        clabels = list(distribution.keys())
        dist = list(distribution.values())
        return UnequalLabelsInitializer(dist, clabels)

    if isinstance(distribution, tuple):
        num_classes, prototypes_per_class = distribution
        return EqualLabelsInitializer(num_classes, prototypes_per_class)

    if isinstance(distribution, list):
        return UnequalLabelsInitializer(distribution)

    # The two sentences used to be concatenated without a separating space.
    msg = f"`distribution` not understood. " \
          f"You have provided: {distribution=}."
    raise ValueError(msg)
|
||||
|
||||
|
||||
def _precheck_initializer(initializer):
    """Raise ``TypeError`` unless `initializer` is a `ComponentsInitializer`."""
    if isinstance(initializer, ComponentsInitializer):
        return
    emsg = (f"`initializer` has to be some subtype of "
            f"{ComponentsInitializer}. "
            f"You have provided: {initializer=} instead.")
    raise TypeError(emsg)
|
||||
|
||||
|
||||
class LinearMapping(torch.nn.Module):
    """LinearMapping is a learnable Mapping Matrix.

    The matrix is stored as the parameter ``_omega``. It is either taken
    from `initialized_linearmapping` or produced by
    ``initializer.generate(mapping_shape)``.
    """
    def __init__(self,
                 mapping_shape=None,
                 initializer=None,
                 *,
                 initialized_linearmapping=None):
        super().__init__()

        # Ignore all initialization settings if initialized_linearmapping
        # is given.
        if initialized_linearmapping is not None:
            self._register_mapping(initialized_linearmapping)
            # This condition previously read the undefined name
            # `num_components`, raising NameError on this branch.
            if mapping_shape is not None or initializer is not None:
                wmsg = "Arguments ignored while initializing LinearMapping"
                warnings.warn(wmsg)
        else:
            self._initialize_mapping(mapping_shape, initializer)

    @property
    def mapping_shape(self):
        """Shape of the stored mapping parameter."""
        return self._omega.shape

    def _register_mapping(self, components):
        """Wrap `components` in a ``Parameter`` and register it as ``_omega``."""
        self.register_parameter("_omega", Parameter(components))

    def _initialize_mapping(self, mapping_shape, initializer):
        """Generate the mapping with `initializer` and register it.

        Raises:
            TypeError: if `initializer` fails `_precheck_initializer`.
        """
        _precheck_initializer(initializer)
        _mapping = initializer.generate(mapping_shape)
        self._register_mapping(_mapping)

    @property
    def mapping(self):
        """Detached view of the mapping parameter."""
        return self._omega.detach()

    def forward(self):
        """Return the mapping parameter itself (not detached)."""
        return self._omega
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
    """Components is a set of learnable Tensors.

    The tensors live in the parameter ``_components``; they are either
    supplied through `initialized_components` or generated by
    ``initializer.generate(num_components)``.
    """
    def __init__(self,
                 num_components=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        super().__init__()

        # Without pre-built components, everything comes from the
        # initializer.
        if initialized_components is None:
            self._initialize_components(num_components, initializer)
            return

        # Pre-built components win; any other setting is ignored with a
        # warning.
        self._register_components(initialized_components)
        if not (num_components is None and initializer is None):
            warnings.warn("Arguments ignored while initializing Components")

    @property
    def num_components(self):
        """Length of the components parameter along its first dimension."""
        return len(self._components)

    def _register_components(self, components):
        """(Re-)register `components` as the ``_components`` parameter."""
        wrapped = Parameter(components)
        self.register_parameter("_components", wrapped)

    def _initialize_components(self, num_components, initializer):
        """Generate `num_components` components and register them."""
        _precheck_initializer(initializer)
        self._register_components(initializer.generate(num_components))

    def add_components(self,
                       num=1,
                       initializer=None,
                       *,
                       initialized_components=None):
        """Append components, either given directly or freshly generated."""
        if initialized_components is None:
            _precheck_initializer(initializer)
            extra = initializer.generate(num)
        else:
            extra = initialized_components
        self._register_components(torch.cat([self._components, extra]))

    def remove_components(self, indices=None):
        """Drop the components at `indices`.

        Returns the boolean keep-mask that was applied.
        """
        keep = torch.ones(self.num_components, dtype=torch.bool)
        keep[indices] = False
        self._register_components(self._components[keep])
        return keep

    @property
    def components(self):
        """Detached view of the components parameter."""
        return self._components.detach()

    def forward(self):
        """Return the components parameter itself (not detached)."""
        return self._components

    def extra_repr(self):
        shape = tuple(self._components.shape)
        return f"(components): (shape: {shape})"
|
||||
|
||||
|
||||
class LabeledComponents(Components):
    """LabeledComponents generate a set of components and a set of labels.

    Every Component has a label assigned. Labels are kept in the buffer
    ``_labels``.
    """
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, component_labels = parse_data_arg(
                initialized_components)
            super().__init__(initialized_components=components)
            self._register_labels(component_labels)
        else:
            labels = get_labels_object(distribution)
            # Must be set before super().__init__(): that call reaches
            # `_initialize_components`, which reads this attribute.
            self.initial_distribution = labels.distribution
            _labels = labels.generate()
            super().__init__(len(_labels), initializer=initializer)
            self._register_labels(_labels)

    def _register_labels(self, labels):
        """(Re-)register `labels` as the ``_labels`` buffer."""
        self.register_buffer("_labels", labels)

    @property
    def distribution(self):
        """Dict mapping each distinct label to its number of components."""
        clabels, counts = torch.unique(self._labels,
                                       sorted=True,
                                       return_counts=True)
        return dict(zip(clabels.tolist(), counts.tolist()))

    def _initialize_components(self, num_components, initializer):
        """Generate components, passing the distribution to class-aware
        initializers."""
        if isinstance(initializer, ClassAwareInitializer):
            _precheck_initializer(initializer)
            _components = initializer.generate(num_components,
                                               self.initial_distribution)
            self._register_components(_components)
        else:
            super()._initialize_components(num_components, initializer)

    def add_components(self, distribution, initializer):
        """Append new labeled components described by `distribution`."""
        _precheck_initializer(initializer)

        # Labels
        labels = get_labels_object(distribution)
        new_labels = labels.generate()

        # Components
        if isinstance(initializer, ClassAwareInitializer):
            # Class-aware initializers only understand list/dict
            # distributions. A plain {label: count} dict is forwarded
            # unchanged; tuple and "num_classes"-dict forms are translated
            # to the per-class list first (they used to be passed raw and
            # crashed inside the initializer).
            if isinstance(distribution,
                          dict) and "num_classes" not in distribution:
                sample_dist = distribution
            else:
                sample_dist = labels.distribution
            _new = initializer.generate(len(new_labels), sample_dist)
        else:
            _new = initializer.generate(len(new_labels))

        # Register only after generation succeeded, so a failing
        # initializer cannot leave labels and components out of sync.
        self._register_labels(torch.cat([self._labels, new_labels]))
        self._register_components(torch.cat([self._components, _new]))

    def remove_components(self, indices=None):
        """Drop the components at `indices` together with their labels.

        Returns the boolean keep-mask, like `Components.remove_components`.
        """
        # Components
        mask = super().remove_components(indices)

        # Labels
        _labels = self._labels[mask]
        self._register_labels(_labels)
        return mask

    @property
    def component_labels(self):
        """Tensor containing the component labels."""
        return self._labels.detach()

    def forward(self):
        """Return the pair ``(components, labels)``."""
        return super().forward(), self._labels
|
||||
|
||||
|
||||
class ReasoningComponents(Components):
    r"""ReasoningComponents generate a set of components and a set of reasoning matrices.

    Every Component has a reasoning matrix assigned.

    A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A components can reason in favour (positive) of a
    class, against (negative) a class or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three element probability distribution.

    """
    def __init__(self,
                 reasonings=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, reasonings = initialized_components

            super().__init__(initialized_components=components)
            self._register_reasonings(reasonings)
        else:
            # Generate first (the component count is derived from the
            # result) but register only after Module.__init__ has run:
            # register_parameter() refuses to work on an uninitialized
            # Module.
            _reasonings = self._generate_reasonings(reasonings)
            super().__init__(len(_reasonings), initializer=initializer)
            self._register_reasonings(_reasonings)

    @staticmethod
    def _generate_reasonings(reasonings):
        """Return the reasonings tensor described by `reasonings`.

        A tuple ``(num_classes, num_components)`` selects a
        `ZeroReasoningsInitializer`; any other object is expected to offer
        a ``generate()`` method.
        """
        if isinstance(reasonings, tuple):
            num_classes, num_components = reasonings
            reasonings = ZeroReasoningsInitializer(num_classes, num_components)
        return reasonings.generate()

    def _register_reasonings(self, reasonings):
        """Register `reasonings` as the ``_reasonings`` parameter."""
        # register_parameter() only accepts Parameter or None, so plain
        # tensors are wrapped here.
        if reasonings is not None and not isinstance(reasonings, Parameter):
            reasonings = Parameter(reasonings)
        self.register_parameter("_reasonings", reasonings)

    def _initialize_reasonings(self, reasonings):
        """Generate and register the reasonings.

        Only valid once ``Module.__init__`` has run.
        """
        self._register_reasonings(self._generate_reasonings(reasonings))

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def forward(self):
        """Return the pair ``(components, reasonings)``."""
        return super().forward(), self._reasonings
|
233
prototorch/components/initializers.py
Normal file
233
prototorch/components/initializers.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""ProtoTorch Initializers."""
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
def parse_data_arg(data_arg):
    """Normalize `data_arg` into a ``(data, targets)`` pair of tensors.

    A ``Dataset`` is wrapped in a ``DataLoader`` with one full-size batch.
    A ``DataLoader`` is drained, concatenating its batches onto initially
    empty tensors. Anything else is unpacked as ``(data, targets)``, and
    non-tensor members are converted with ``torch.Tensor`` after a warning.
    """
    if isinstance(data_arg, Dataset):
        data_arg = DataLoader(data_arg, batch_size=len(data_arg))

    if not isinstance(data_arg, DataLoader):
        data, targets = data_arg
        if not isinstance(data, torch.Tensor):
            warnings.warn(f"Converting data to {torch.Tensor}.")
            data = torch.Tensor(data)
        if not isinstance(targets, torch.Tensor):
            warnings.warn(f"Converting targets to {torch.Tensor}.")
            targets = torch.Tensor(targets)
        return data, targets

    # Accumulation starts from empty default-dtype tensors, exactly as
    # before, so the dtype of the result is unchanged.
    data = torch.tensor([])
    targets = torch.tensor([])
    for batch_x, batch_y in data_arg:
        data = torch.cat([data, batch_x])
        targets = torch.cat([targets, batch_y])
    return data, targets
|
||||
|
||||
|
||||
def get_subinitializers(data, targets, clabels, subinit_type):
    """Create one `subinit_type` instance per class label.

    Each instance is built from the rows of `data` whose entry in `targets`
    equals the label. Returns a dict keyed by class label.
    """
    return {
        clabel: subinit_type(data[targets == clabel])
        for clabel in clabels
    }
|
||||
|
||||
|
||||
# Components
|
||||
class ComponentsInitializer:
    """Base class of all component initializers."""
    def generate(self, number_of_components):
        """Produce `number_of_components` components; subclasses override."""
        raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class DimensionAwareInitializer(ComponentsInitializer):
    """Initializer that remembers the shape of a single component."""
    def __init__(self, dims):
        super().__init__()
        # Always store a tuple: iterables are converted, scalars wrapped.
        self.components_dims = (tuple(dims) if isinstance(dims, Iterable)
                                else (dims, ))
|
||||
|
||||
|
||||
class OnesInitializer(DimensionAwareInitializer):
    """Generate components filled with ones, multiplied by `scale`."""
    def __init__(self, dims, scale=1.0):
        super().__init__(dims)
        self.scale = scale

    def generate(self, length):
        """Return a tensor of shape ``(length, *components_dims)``."""
        ones = torch.ones((length, ) + self.components_dims)
        return ones * self.scale
|
||||
|
||||
|
||||
class ZerosInitializer(DimensionAwareInitializer):
    """Generate components filled with zeros."""
    def generate(self, length):
        """Return a zero tensor of shape ``(length, *components_dims)``."""
        shape = (length, ) + self.components_dims
        return torch.zeros(shape)
|
||||
|
||||
|
||||
class UniformInitializer(DimensionAwareInitializer):
    """Generate components drawn uniformly and multiplied by `scale`."""
    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(dims)
        self.minimum, self.maximum = minimum, maximum
        self.scale = scale

    def generate(self, length):
        """Return a tensor of shape ``(length, *components_dims)``."""
        shape = (length, ) + self.components_dims
        # Fill in place with uniform samples, then scale.
        drawn = torch.ones(shape).uniform_(self.minimum, self.maximum)
        return drawn * self.scale
|
||||
|
||||
|
||||
class DataAwareInitializer(ComponentsInitializer):
    """Initializer that holds on to `data` and an output `transform`."""
    def __init__(self, data, transform=torch.nn.Identity()):
        super().__init__()
        self.transform = transform
        self.data = data

    def __del__(self):
        # Drop the reference to the data when the initializer is destroyed.
        delattr(self, "data")
|
||||
|
||||
|
||||
class SelectionInitializer(DataAwareInitializer):
    """Pick components at random positions of the stored data."""
    def generate(self, length):
        """Return `length` transformed rows of ``self.data``.

        Indices are drawn independently from ``[0, len(self.data))``.
        """
        picks = torch.LongTensor(length).random_(0, len(self.data))
        chosen = self.data[picks]
        return self.transform(chosen)
|
||||
|
||||
|
||||
class MeanInitializer(DataAwareInitializer):
    """Use the mean over the first data dimension for every component."""
    def generate(self, length):
        """Return the transformed mean, repeated `length` times."""
        center = torch.mean(self.data, dim=0)
        # One repetition factor per output dim: `length` copies along a new
        # leading dim, the mean's own dims untouched.
        tiling = [length] + [1 for _ in center.shape]
        return self.transform(center.repeat(tiling))
|
||||
|
||||
|
||||
class ClassAwareInitializer(DataAwareInitializer):
    """Data-aware initializer that also knows the targets of the data.

    `data` is run through `parse_data_arg`, so it may be anything that
    function accepts. Subclasses are expected to provide
    ``self.initializers``, a dict of per-class sub-initializers keyed by
    class label.
    """
    def __init__(self, data, transform=torch.nn.Identity()):
        data, targets = parse_data_arg(data)
        super().__init__(data, transform)
        self.targets = targets
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    def _get_samples_from_initializer(self, length, dist=None):
        """Collect samples from the per-class sub-initializers.

        `dist` may be a dict ``{class_label: count}`` or a list of counts
        that is paired with ``self.clabels``. When it is empty or None,
        ``length // num_classes`` samples are requested for every class.
        The stacked result is passed through ``self.transform`` under
        ``torch.no_grad()``.
        """
        if not dist:
            per_class = length // self.num_classes
            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
        if isinstance(dist, list):
            dist = dict(zip(self.clabels, dist))
        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
        out = torch.vstack(samples)
        with torch.no_grad():
            out = self.transform(out)
        return out

    def __del__(self):
        # __del__ also runs on objects whose __init__ failed before these
        # attributes were set (e.g. parse_data_arg raised); a bare `del`
        # would then raise AttributeError from the finalizer.
        for name in ("data", "targets"):
            self.__dict__.pop(name, None)
|
||||
|
||||
|
||||
class StratifiedMeanInitializer(ClassAwareInitializer):
    """Generate components from the per-class means of the data."""
    def __init__(self, data, **kwargs):
        super().__init__(data, **kwargs)
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels, MeanInitializer)

    def generate(self, length, dist=None):
        """Return class-wise mean components.

        `dist` is optional so that the single-argument call made by
        `Components` (``initializer.generate(num_components)``) works; the
        sampling helper already handles an empty distribution.
        """
        samples = self._get_samples_from_initializer(length, dist)
        return samples
|
||||
|
||||
|
||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
    """Generate components by selecting data samples class by class.

    When `noise` is not None it is added to every generated sample.
    """
    def __init__(self, data, noise=None, **kwargs):
        super().__init__(data, **kwargs)
        self.noise = noise
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels,
                                                SelectionInitializer)

    def add_noise_v1(self, x):
        """Add ``self.noise`` to `x`."""
        return x + self.noise

    def add_noise_v2(self, x):
        """Shifts some dimensions of the data randomly."""
        n1 = torch.rand_like(x)
        n2 = torch.rand_like(x)
        # Difference of two Bernoulli draws: each entry is -1, 0 or +1.
        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
        return x + (self.noise * mask)

    def generate(self, length, dist=None):
        """Return class-wise selected samples, optionally with noise.

        `dist` is optional so that the single-argument call made by
        `Components` (``initializer.generate(num_components)``) works; the
        sampling helper already handles an empty distribution.
        """
        samples = self._get_samples_from_initializer(length, dist)
        if self.noise is not None:
            samples = self.add_noise_v1(samples)
        return samples
|
||||
|
||||
|
||||
# Omega matrix
|
||||
class PcaInitializer(DataAwareInitializer):
    """Initialize a mapping from a low-rank PCA of the stored data."""
    def generate(self, shape):
        """Return the third output of ``torch.pca_lowrank``.

        `shape` must unpack into ``(input_dim, latent_dim)``; only
        ``latent_dim`` is used, as the ``q`` argument.
        """
        _input_dim, latent_dim = shape
        pca_result = torch.pca_lowrank(self.data, q=latent_dim)
        _, _eigenvalues, eigenvectors = pca_result
        return eigenvectors
|
||||
|
||||
|
||||
# Labels
|
||||
class LabelsInitializer:
    """Base class of all label initializers."""
    def generate(self):
        """Produce the labels tensor; subclasses override."""
        raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class UnequalLabelsInitializer(LabelsInitializer):
    """Labels with an individual number of repetitions per class."""
    def __init__(self, dist, clabels=None):
        self.dist = dist
        # Fall back to 0..len(dist)-1 when no (or empty) labels are given.
        self.clabels = clabels or range(len(self.dist))

    @property
    def distribution(self):
        """The per-class counts as passed in."""
        return self.dist

    def generate(self):
        """Return a LongTensor with every label repeated by its count."""
        targets = []
        for clabel, count in zip(self.clabels, self.dist):
            targets.extend([clabel] * count)
        return torch.LongTensor(targets)
|
||||
|
||||
|
||||
class EqualLabelsInitializer(LabelsInitializer):
    """Labels with the same number of repetitions for every class."""
    def __init__(self, classes, per_class):
        self.per_class = per_class
        self.classes = classes

    @property
    def distribution(self):
        """List holding `per_class` once for each class."""
        return [self.per_class] * self.classes

    def generate(self):
        """Return labels ``0..classes-1``, each repeated `per_class` times."""
        # Rows are copies of arange(classes); transposing before
        # flattening groups equal labels together.
        grid = torch.arange(self.classes).repeat(self.per_class, 1)
        return grid.T.flatten()
|
||||
|
||||
|
||||
# Reasonings
|
||||
class ReasoningsInitializer:
    """Base class of all reasoning initializers."""
    def generate(self, length):
        """Produce the reasonings tensor; subclasses override."""
        raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class ZeroReasoningsInitializer(ReasoningsInitializer):
    """Reasonings that are zero everywhere."""
    def __init__(self, classes, length):
        self.length = length
        self.classes = classes

    def generate(self):
        """Return a zero tensor of shape ``(length, classes, 2)``."""
        shape = (self.length, self.classes, 2)
        return torch.zeros(shape)
|
||||
|
||||
|
||||
# Aliases
|
||||
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||
SMI = StratifiedMeanInitializer
|
||||
Random = RandomInitializer = UniformInitializer
|
||||
Zeros = ZerosInitializer
|
||||
Ones = OnesInitializer
|
||||
PCA = PcaInitializer
|
6
prototorch/datasets/__init__.py
Normal file
6
prototorch/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""ProtoTorch datasets."""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .sklearn import Blobs, Circles, Iris, Moons, Random
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
98
prototorch/datasets/abstract.py
Normal file
98
prototorch/datasets/abstract.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""ProtoTorch abstract dataset classes.
|
||||
|
||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
||||
|
||||
For the original code, see:
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class NumpyDataset(torch.utils.data.TensorDataset):
    """Create a PyTorch TensorDataset from NumPy arrays."""
    def __init__(self, data, targets):
        # Keep both tensors reachable as attributes in addition to the
        # TensorDataset storage.
        self.data = torch.Tensor(data)
        self.targets = torch.LongTensor(targets)
        super().__init__(self.data, self.targets)
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
    """Abstract dataset class to be inherited."""

    _repr_indent = 2

    def __init__(self, root):
        """Store `root`, expanding a leading ``~`` when it is a path string."""
        # Plain builtin types replace the private
        # `torch._six.string_classes` helper, which newer PyTorch releases
        # no longer provide.
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
|
||||
|
||||
|
||||
class ProtoDataset(Dataset):
    """Abstract dataset class to be inherited.

    Loads ``(data, targets)`` from the processed training or test file
    below ``root/<ClassName>/processed``.
    """

    training_file = "training.pt"
    test_file = "test.pt"

    def __init__(self, root, train=True, download=True, verbose=True):
        super().__init__(root)
        self.train = train  # training set or test set
        self.verbose = verbose

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. "
                               "You can use download=True to download it")

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file

        file_path = os.path.join(self.processed_folder, data_file)
        self.data, self.targets = torch.load(file_path)

    @property
    def raw_folder(self):
        """Folder ``root/<ClassName>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self):
        """Folder ``root/<ClassName>/processed``."""
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self):
        """Map every entry of ``self.classes`` to its position."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """True when both processed files (training and test) exist."""
        folder = self.processed_folder
        return all(
            os.path.exists(os.path.join(folder, name))
            for name in (self.training_file, self.test_file))

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        indent = " " * self._repr_indent
        return "\n".join([head] + [indent + line for line in body])

    def extra_repr(self):
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"

    def __len__(self):
        return len(self.data)

    def _download(self):
        raise NotImplementedError
|
137
prototorch/datasets/sklearn.py
Normal file
137
prototorch/datasets/sklearn.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Thin wrappers for a few scikit-learn datasets.
|
||||
|
||||
URL:
|
||||
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets
|
||||
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Sequence, Union
|
||||
|
||||
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
||||
make_classification, make_moons)
|
||||
|
||||
from prototorch.datasets.abstract import NumpyDataset
|
||||
|
||||
|
||||
class Iris(NumpyDataset):
    """Iris Dataset by Ronald Fisher introduced in 1936.

    The dataset contains four measurements from flowers of three species of iris.

    .. list-table:: Iris
        :header-rows: 1

        * - dimensions
          - classes
          - training size
          - validation size
          - test size
        * - 4
          - 3
          - 150
          - 0
          - 0

    :param dims: select a subset of dimensions
    """
    def __init__(self, dims: Sequence[int] = None):
        features, labels = load_iris(return_X_y=True)
        # An empty or missing selection keeps all feature columns.
        if dims:
            features = features[:, dims]
        super().__init__(features, labels)
|
||||
|
||||
|
||||
class Blobs(NumpyDataset):
    """Generate isotropic Gaussian blobs for clustering.

    Read more at
    https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.

    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 seed: Union[None, int] = 0):
        blob_options = dict(centers=None, random_state=seed, shuffle=False)
        features, labels = make_blobs(num_samples, num_features,
                                      **blob_options)
        super().__init__(features, labels)
|
||||
|
||||
|
||||
class Random(NumpyDataset):
    """Generate a random n-class classification problem.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.

    Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 num_classes: int = 2,
                 num_clusters: int = 2,
                 num_informative: Union[None, int] = None,
                 separation: float = 1.0,
                 seed: Union[None, int] = 0):
        if not num_informative:
            import math

            # Smallest count with 2**count >= classes * clusters.
            total_clusters = num_classes * num_clusters
            num_informative = math.ceil(math.log2(total_clusters))
        if num_features < num_informative:
            warnings.warn("Generating more features than requested.")
            num_features = num_informative
        options = dict(n_informative=num_informative,
                       n_redundant=0,
                       n_classes=num_classes,
                       n_clusters_per_class=num_clusters,
                       class_sep=separation,
                       random_state=seed,
                       shuffle=False)
        features, labels = make_classification(num_samples, num_features,
                                               **options)
        super().__init__(features, labels)
|
||||
|
||||
|
||||
class Circles(NumpyDataset):
    """Make a large circle containing a smaller circle in 2D.

    A simple toy dataset to visualize clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html

    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 factor: float = 0.8,
                 seed: Union[None, int] = 0):
        options = dict(noise=noise,
                       factor=factor,
                       random_state=seed,
                       shuffle=False)
        features, labels = make_circles(num_samples, **options)
        super().__init__(features, labels)
|
||||
|
||||
|
||||
class Moons(NumpyDataset):
    """Make two interleaving half circles.

    A simple toy dataset to visualize clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html

    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 seed: Union[None, int] = 0):
        options = dict(noise=noise, random_state=seed, shuffle=False)
        features, labels = make_moons(num_samples, **options)
        super().__init__(features, labels)
|
57
prototorch/datasets/spiral.py
Normal file
57
prototorch/datasets/spiral.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Spiral dataset for binary classification."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def make_spiral(num_samples=500, noise=0.3):
    """Generates the Spiral Dataset.

    For use in Prototorch use `prototorch.datasets.Spiral` instead.

    Returns ``(x, y)``: the points of two spiral arms of
    ``num_samples // 2`` points each, and labels that are 0 for the first
    arm and 1 for the second.
    """
    half = num_samples // 2

    def spiral_arm(count, phase):
        # Build one arm; the x noise is drawn before the y noise for every
        # point.
        arm = []
        for step in range(count):
            radius = step / num_samples * 5
            angle = 1.75 * step / count * 2 * np.pi + phase
            px = radius * np.sin(angle) + np.random.rand(1) * noise
            py = radius * np.cos(angle) + np.random.rand(1) * noise
            arm.append([px, py])
        return arm

    first_arm = spiral_arm(count=half, phase=0)
    second_arm = spiral_arm(count=half, phase=np.pi)

    arms = [
        np.array(first_arm).reshape(half, -1),
        np.array(second_arm).reshape(half, -1),
    ]
    x = np.concatenate(arms, axis=0)
    y = np.concatenate([np.zeros(half), np.ones(half)])
    return x, y
|
||||
|
||||
|
||||
class Spiral(torch.utils.data.TensorDataset):
    """Spiral dataset for binary classification.

    This datasets consists of two spirals of two different classes.

    .. list-table:: Spiral
        :header-rows: 1

        * - dimensions
          - classes
          - training size
          - validation size
          - test size
        * - 2
          - 2
          - num_samples
          - 0
          - 0

    :param num_samples: number of random samples
    :param noise: noise added to the spirals
    """
    def __init__(self, num_samples: int = 500, noise: float = 0.3):
        points, labels = make_spiral(num_samples, noise)
        data = torch.Tensor(points)
        targets = torch.LongTensor(labels)
        super().__init__(data, targets)
|
121
prototorch/datasets/tecator.py
Normal file
121
prototorch/datasets/tecator.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Tecator dataset for classification.
|
||||
|
||||
URL:
|
||||
http://lib.stat.cmu.edu/datasets/tecator
|
||||
|
||||
LICENCE / TERMS / COPYRIGHT:
|
||||
This is the Tecator data set: The task is to predict the fat content
|
||||
of a meat sample on the basis of its near infrared absorbance spectrum.
|
||||
-------------------------------------------------------------------------
|
||||
1. Statement of permission from Tecator (the original data source)
|
||||
|
||||
These data are recorded on a Tecator Infratec Food and Feed Analyzer
|
||||
working in the wavelength range 850 - 1050 nm by the Near Infrared
|
||||
Transmission (NIT) principle. Each sample contains finely chopped pure
|
||||
meat with different moisture, fat and protein contents.
|
||||
|
||||
If results from these data are used in a publication we want you to
|
||||
mention the instrument and company name (Tecator) in the publication.
|
||||
In addition, please send a preprint of your article to
|
||||
|
||||
Karin Thente, Tecator AB,
|
||||
Box 70, S-263 21 Hoganas, Sweden
|
||||
|
||||
The data are available in the public domain with no responsability from
|
||||
the original data source. The data can be redistributed as long as this
|
||||
permission note is attached.
|
||||
|
||||
For more information about the instrument - call Perstorp Analytical's
|
||||
representative in your area.
|
||||
|
||||
Description:
|
||||
For each meat sample the data consists of a 100 channel spectrum of
|
||||
absorbances and the contents of moisture (water), fat and protein.
|
||||
The absorbance is -log10 of the transmittance
|
||||
measured by the spectrometer. The three contents, measured in percent,
|
||||
are determined by analytic chemistry.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.datasets.utils import download_file_from_google_drive
|
||||
|
||||
from prototorch.datasets.abstract import ProtoDataset
|
||||
|
||||
|
||||
class Tecator(ProtoDataset):
    """`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.

    Wavelength measurements of meat samples, labelled as low or high fat.

    .. list-table:: Tecator
       :header-rows: 1

       * - dimensions
         - classes
         - training size
         - validation size
         - test size
       * - 100
         - 2
         - 129
         - 43
         - 43
    """

    # One (google_storage_id, md5hash) pair per file to fetch.
    _resources = [
        ("1P9WIYnyxFPh6f1vqAbnKfK8oYmUgyV83",
         "ba5607c580d0f91bb27dc29d13c2f8df"),
    ]
    classes = ["0 - low_fat", "1 - high_fat"]

    def __getitem__(self, index):
        """Return the sample at ``index`` together with its target as an ``int``."""
        sample = self.data[index]
        label = int(self.targets[index])
        return sample, label

    def _download(self):
        """Fetch the raw archive and write the processed train/test files.

        Does nothing when ``self._check_exists()`` is already true.
        """
        if self._check_exists():
            return

        archive_name = "tecator.npz"

        def report(message):
            # Progress output is opt-in through ``self.verbose``.
            if self.verbose:
                print(message)

        def store(payload, target_name):
            # Serialise one split into the processed folder.
            destination = os.path.join(self.processed_folder, target_name)
            with open(destination, "wb") as handle:
                torch.save(payload, handle)

        report("Making directories...")
        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        report("Downloading...")
        for fileid, md5 in self._resources:
            download_file_from_google_drive(fileid,
                                            root=self.raw_folder,
                                            filename=archive_name,
                                            md5=md5)

        report("Processing...")
        archive_path = os.path.join(self.raw_folder, archive_name)
        with np.load(archive_path, allow_pickle=False) as archive:
            x_train, y_train = archive["x_train"], archive["y_train"]
            x_test, y_test = archive["x_test"], archive["y_test"]
            training_set = [torch.Tensor(x_train), torch.LongTensor(y_train)]
            test_set = [torch.Tensor(x_test), torch.LongTensor(y_test)]

        store(training_set, self.training_file)
        store(test_set, self.test_file)

        report("Done!")
|
@@ -0,0 +1,5 @@
|
||||
"""ProtoTorch functions."""
|
||||
|
||||
from .activations import identity, sigmoid_beta, swish_beta
|
||||
from .competitions import knnc, wtac
|
||||
from .pooling import *
|
||||
|
@@ -5,44 +5,58 @@ import torch
|
||||
ACTIVATIONS = dict()
|
||||
|
||||
|
||||
def register_activation(func):
|
||||
ACTIVATIONS[func.__name__] = func
|
||||
return func
|
||||
def register_activation(fn):
    """Store ``fn`` in ``ACTIVATIONS`` under its ``__name__`` and return it.

    The function is handed back unchanged, so this can be applied as a
    decorator.
    """
    ACTIVATIONS[fn.__name__] = fn
    return fn
|
||||
|
||||
|
||||
@register_activation
|
||||
def identity(input, **kwargs):
|
||||
""":math:`f(x) = x`"""
|
||||
return input
|
||||
def identity(x, beta=0.0):
|
||||
"""Identity activation function.
|
||||
|
||||
|
||||
@register_activation
|
||||
def sigmoid_beta(input, beta=10):
|
||||
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
||||
Definition:
|
||||
:math:`f(x) = x`
|
||||
|
||||
Keyword Arguments:
|
||||
beta (float): Parameter :math:`\\beta`
|
||||
beta (`float`): Ignored.
|
||||
"""
|
||||
out = torch.reciprocal(1.0 + torch.exp(-beta * input))
|
||||
return x
|
||||
|
||||
|
||||
@register_activation
|
||||
def sigmoid_beta(x, beta=10.0):
    r"""Sigmoid activation function with scaling.

    Definition:
        :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`

    Keyword Arguments:
        beta (`float`): Scaling parameter :math:`\beta`
    """
    # Use the built-in sigmoid rather than spelling out
    # 1 / (1 + exp(-beta * x)): once exp() overflows to inf the explicit
    # form still evaluates to 0 in the forward pass, but its backward pass
    # multiplies 0 by inf and produces NaN gradients.
    out = torch.sigmoid(beta * x)
    return out
|
||||
|
||||
|
||||
@register_activation
|
||||
def swish_beta(input, beta=10):
|
||||
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
||||
def swish_beta(x, beta=10.0):
|
||||
r"""Swish activation function with scaling.
|
||||
|
||||
Definition:
|
||||
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
|
||||
|
||||
Keyword Arguments:
|
||||
beta (float): Parameter :math:`\\beta`
|
||||
beta (`float`): Scaling parameter :math:`\beta`
|
||||
"""
|
||||
out = input * sigmoid_beta(input, beta=beta)
|
||||
out = x * sigmoid_beta(x, beta=beta)
|
||||
return out
|
||||
|
||||
|
||||
def get_activation(funcname):
|
||||
"""Deserialize the activation function."""
|
||||
if callable(funcname):
|
||||
return funcname
|
||||
else:
|
||||
if funcname in ACTIVATIONS:
|
||||
return ACTIVATIONS.get(funcname)
|
||||
else:
|
||||
raise NameError(f'Activation {funcname} was not found.')
|
||||
if funcname in ACTIVATIONS:
|
||||
return ACTIVATIONS.get(funcname)
|
||||
raise NameError(f"Activation {funcname} was not found.")
|
||||
|
@@ -3,13 +3,26 @@
|
||||
import torch
|
||||
|
||||
|
||||
def wtac(distances, labels):
|
||||
def wtac(distances: torch.Tensor,
         labels: torch.LongTensor) -> (torch.LongTensor):
    """Winner-Takes-All-Competition.

    For each row of ``distances`` the column holding the smallest value
    wins; the entries of ``labels`` at the winning columns are returned
    (squeezed).
    """
    _, nearest = torch.min(distances, dim=1)
    return labels[nearest].squeeze()
|
||||
|
||||
|
||||
def knnc(distances, labels, k):
|
||||
def knnc(distances: torch.Tensor,
         labels: torch.LongTensor,
         k: int = 1) -> (torch.LongTensor):
    """K-Nearest-Neighbors-Competition.

    For each row of ``distances`` the ``k`` columns with the smallest
    values are selected and the most frequent of their ``labels`` is
    returned.
    """
    # topk picks the largest entries, so negate to get the k smallest.
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    # Majority vote over the k selected labels of every row.
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels
|
||||
|
@@ -1,14 +1,21 @@
|
||||
"""ProtoTorch distance functions."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||
equal_int_shape, get_flat)
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
"""Compute the squared Euclidean distance between :math:`x` and :math:`y`.
|
||||
r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
|
||||
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
@@ -17,42 +24,56 @@ def squared_euclidean_distance(x, y):
|
||||
|
||||
|
||||
def euclidean_distance(x, y):
|
||||
"""Compute the Euclidean distance between :math:`x` and :math:`y`.
|
||||
r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
|
||||
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def lpnorm_distance(x, y, p):
|
||||
"""Compute :math:`{\\langle x, y \\rangle}_p`.
|
||||
def euclidean_distance_v2(x, y):
    r"""Compute the Euclidean distance between each row of x and each row of y.

    Both inputs are flattened to ``(n, features)`` first.

    :returns: Distance tensor of shape ``(nx, ny)``
    :rtype: `torch.tensor`
    """
    # Same flattening as ``get_flat``: keep dim 0, merge the rest.
    x = x.view(x.size(0), -1)
    y = y.view(y.size(0), -1)
    diff = y - x.unsqueeze(1)  # (nx, ny, ndim)
    # Reduce over the feature axis directly. Going through the Gram matrix
    # ``diff @ diff^T`` and taking the square root of all its entries
    # allocates (nx, ny, ny) values and takes sqrt of off-diagonal dot
    # products, which may be negative or zero; even though only the
    # diagonal is kept, the sqrt backward pass then yields NaN gradients.
    distances = diff.pow(2).sum(dim=-1).sqrt()  # (nx, ny)
    return distances
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
|
||||
def lpnorm_distance(x, y, p):
    r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.

    Also known as Minkowski distance.

    Compute :math:`{\| \bm x - \bm y \|}_p`.

    Calls ``torch.cdist``

    :param p: p parameter of the lp norm
    """
    # Flatten everything behind the first dimension before the pairwise call.
    flat_x = x.view(x.size(0), -1)
    flat_y = y.view(y.size(0), -1)
    return torch.cdist(flat_x, flat_y, p=p)
|
||||
|
||||
|
||||
def omega_distance(x, y, omega):
|
||||
"""Omega distance.
|
||||
r"""Omega distance.
|
||||
|
||||
Compute :math:`{\\langle \\Omega x, \\Omega y \\rangle}_p`
|
||||
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
Expected dimension of omega is 2.
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
@@ -60,14 +81,13 @@ def omega_distance(x, y, omega):
|
||||
|
||||
|
||||
def lomega_distance(x, y, omegas):
|
||||
"""Localized Omega distance.
|
||||
r"""Localized Omega distance.
|
||||
|
||||
Compute :math:`{\\langle \\Omega_k x, \\Omega_k y_k \\rangle}_p`
|
||||
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
|
||||
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
Expected dimension of omegas is 3.
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = get_flat(x, y)
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
@@ -76,3 +96,164 @@ def lomega_distance(x, y, omegas):
|
||||
distances = torch.sum(differences_squared, dim=2)
|
||||
distances = distances.permute(1, 0)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
    r"""Compute the Euclidean distance matrix between two sets of vectors.

    The last dimension must be the vector dimension. The distances are
    computed via the dot-product identity
    :math:`\|x - y\|^2 = \|x\|^2 - 2 x \cdot y + \|y\|^2`, which avoids the
    memory overhead of materialising all pairwise differences.

    - ``x.shape = (number_of_x_vectors, vector_dim)``
    - ``y.shape = (number_of_y_vectors, vector_dim)``

    :param squared: if True, return squared distances (no square root)
    :param epsilon: lower bound applied to the squared distances before the
        square root; pass 0 to disable the bound
    :returns: matrix of distances (number_of_x_vectors, number_of_y_vectors)
    :raises ValueError: if an input is not 2D or the vector dims differ
    """
    for tensor in (x, y):
        if tensor.ndim != 2:
            raise ValueError(
                "The tensor dimension must be two. You provide: tensor.ndim=" +
                str(tensor.ndim) + ".")
    if x.shape[1] != y.shape[1]:
        raise ValueError(
            "The vector shape must be equivalent in both tensors. You provide: tuple(x.shape)[1]="
            + str(tuple(x.shape)[1]) + " and tuple(y.shape)[1]=" +
            str(tuple(y.shape)[1]) + ".")

    x_sq = torch.sum(x**2, dim=1, keepdim=True)  # (nx, 1)
    y_sq = torch.sum(y**2, dim=1, keepdim=True).T  # (1, ny)
    # Matrix product, not torch.dot: torch.dot is restricted to 1D inputs.
    diss = x_sq - 2 * (x @ y.T) + y_sq

    if not squared:
        if epsilon == 0:
            diss = torch.sqrt(diss)
        else:
            # Rounding can push tiny squared distances below zero; bound
            # them from below before taking the root.
            diss = torch.sqrt(torch.clamp(diss, min=epsilon))

    return diss
|
||||
|
||||
|
||||
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
    r"""Tangent distances between signals and prototypes.

    PyTorch translation (by Christoph Raab) of Sascha Saralajew's
    TensorFlow implementation. For more information about tangent
    distances see DOI:10.1109/IJCNN.2016.7727534.

    The subspaces are always assumed to be transposed and must be
    orthogonal (orthogonalize them beforehand). For local non-sparse
    signals, subspaces must be provided.

    Expected shapes:

    - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
    - shape(protos): proto_number x dim1 x dim2 x ... x dimN
    - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)

    ``squared`` and ``epsilon`` are only read in the sparse branch
    (``signals.shape[1] == 1``); the non-sparse branch does not use them.
    """
    signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
    proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
    subspace_int_shape = tuple(subspaces.shape)

    # Raises ValueError if the signal/prototype shapes are incompatible.
    _check_shapes(signal_int_shape, proto_int_shape)

    atom_axes = list(range(3, len(signal_int_shape)))
    # For sparse signals (a single entry on axis 1) use the memory
    # efficient implementation.
    if signal_int_shape[1] == 1:
        signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])

        if len(atom_axes) > 1:
            protos = torch.reshape(protos, [proto_shape[0], -1])

        if subspaces.ndim == 2:
            # Global subspace: clean solution without map.
            # NOTE(review): this branch looks like an unfinished port --
            # torch.dot is documented for 1D tensors only and
            # torch.transpose needs explicit dims; confirm before relying
            # on it.
            projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
                subspaces, torch.transpose(subspaces))

            projected_signals = torch.dot(signals, projectors)
            projected_protos = torch.dot(protos, projectors)

            diss = euclidean_distance_matrix(projected_signals,
                                             projected_protos,
                                             squared=squared,
                                             epsilon=epsilon)

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

        else:

            # Local subspaces: no solution without map possible
            # --> memory efficient but slow!
            # NOTE(review): the trailing comments record the Keras calls
            # these lines were translated from; the torch calls here
            # (bmm without a transpose, torch.dot on matrices,
            # torch.transpose over a ``map`` object, torch.max with a
            # float) do not look equivalent -- verify.
            projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
                subspaces,
                subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])

            projected_protos = (protos @ subspaces
                                ).T  # K.batch_dot(projectors, protos, [1, 1]))

            def projected_norm(projector):
                return torch.sum(torch.dot(signals, projector)**2, axis=1)

            diss = (torch.transpose(map(projected_norm, projectors)) -
                    2 * torch.dot(signals, projected_protos) +
                    torch.sum(projected_protos**2, axis=0, keepdims=True))

            if not squared:
                if epsilon == 0:
                    diss = torch.sqrt(diss)
                else:
                    diss = torch.sqrt(torch.max(diss, epsilon))

            diss = torch.reshape(
                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])

            return torch.permute(diss, [0, 2, 1])

    else:
        # Swap axes 1 and 2 so the prototype axis sits next to the atom
        # axes and broadcasts against ``protos`` in the subtraction.
        signals = signals.permute([0, 2, 1] + atom_axes)

        diff = signals - protos

        # global tangent space
        if subspaces.ndim == 2:
            # The subspace matrix itself is used as the projector.
            projectors = subspaces  #

            # Tangent space projections: flatten the atom axes, project,
            # then restore the shape.
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            projected_diff = diff @ projectors
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[0], signal_shape[2], signal_shape[1]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([0, 2, 1])

        # local tangent spaces
        else:
            # One projector per prototype, taken directly from subspaces.
            projectors = subspaces

            # Tangent space projections: bring the prototype axis to the
            # front so torch.bmm pairs each prototype with its projector.
            diff = torch.reshape(
                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
            diff = diff.permute([1, 0, 2])
            projected_diff = torch.bmm(diff, projectors)
            projected_diff = torch.reshape(
                projected_diff,
                (signal_shape[1], signal_shape[0], signal_shape[2]) +
                signal_shape[3:],
            )

            diss = torch.norm(projected_diff, 2, dim=-1)
            return diss.permute([1, 0, 2]).squeeze(-1)
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
||||
|
94
prototorch/functions/helper.py
Normal file
94
prototorch/functions/helper.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_flat(*args):
    """Flatten each tensor to 2D, keeping its first dimension.

    Returns a list holding one ``(size(0), -1)`` view per argument, in the
    order given.
    """
    flattened = []
    for tensor in args:
        leading = tensor.size(0)
        flattened.append(tensor.view(leading, -1))
    return flattened
|
||||
|
||||
|
||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
    """Compute the Winner-Takes-All accuracy of a prototype based model.

    The winner of each row of ``y_pred`` is the column with the smallest
    value; its entry in ``plabels`` is compared against ``y_true``.

    Returns the share of matching rows in percent. No gradients are
    recorded.
    """
    with torch.no_grad():
        winners = torch.argmin(y_pred, dim=1)
        hits = torch.sum(y_true == plabels[winners])
        fraction = torch.true_divide(hits, len(y_pred))
        return fraction * 100
|
||||
|
||||
|
||||
def predict_label(y_pred, plabels):
    r"""Return, for each row of ``y_pred``, the label of its smallest entry.

    The column index of the row minimum is looked up in ``plabels``. No
    gradients are recorded.
    """
    with torch.no_grad():
        winners = torch.argmin(y_pred, 1)
        return plabels[winners]
|
||||
|
||||
|
||||
def mixed_shape(inputs):
    """Return the shape of ``inputs`` as a tuple of Python ints.

    :param inputs: a torch tensor
    :returns: ``tuple`` with one ``int`` per dimension
    :raises ValueError: if ``inputs`` is not a tensor
    """
    if not torch.is_tensor(inputs):
        raise ValueError("Input must be a tensor.")
    # Every entry of a torch shape is a concrete integer, so no fallback
    # for unknown (None) dimensions is needed here.
    return tuple(int(dim) for dim in inputs.shape)
|
||||
|
||||
|
||||
def equal_int_shape(shape_1, shape_2):
    """Check whether two shapes agree, treating ``None`` as a wildcard.

    Both arguments must be lists or tuples containing only ints and
    ``None``. Shapes of different length never match; otherwise two
    entries match when they are equal or either one is ``None``.

    :raises ValueError: on a wrong container or entry type
    """
    sequence_types = (tuple, list)
    if not (isinstance(shape_1, sequence_types)
            and isinstance(shape_2, sequence_types)):
        raise ValueError("Input shapes must list or tuple.")

    for shape in (shape_1, shape_2):
        for entry in shape:
            if entry is not None and not isinstance(entry, int):
                raise ValueError(
                    "Input shapes must be list or tuple of int and None values.")

    if len(shape_1) != len(shape_2):
        return False

    for left, right in zip(shape_1, shape_2):
        if left is None or right is None:
            continue
        if right != left:
            return False
    return True
|
||||
|
||||
|
||||
def _check_shapes(signal_int_shape, proto_int_shape):
    """Validate signal and prototype shapes against each other.

    Checks, in order: signals have at least 4 dimensions, prototypes at
    least 2, the trailing signal dimensions (from index 3) match the
    prototype dimensions (from index 1), and -- unless the signal has size
    1 on axis 1 -- that axis matches the number of prototypes.

    :returns: ``True`` when every check passes
    :raises ValueError: on the first failing check
    """
    signal_rank = len(signal_int_shape)
    if signal_rank < 4:
        message = ("The number of signal dimensions must be >=4. You provide: " +
                   str(signal_rank))
        raise ValueError(message)

    proto_rank = len(proto_int_shape)
    if proto_rank < 2:
        message = ("The number of proto dimensions must be >=2. You provide: " +
                   str(proto_rank))
        raise ValueError(message)

    signal_atoms = signal_int_shape[3:]
    proto_atoms = proto_int_shape[1:]
    if not equal_int_shape(signal_atoms, proto_atoms):
        message = (
            "The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
            + str(signal_atoms) + " != protos.shape[1:]=" + str(proto_atoms))
        raise ValueError(message)

    # Size 1 on axis 1 marks a sparse signal, which is exempt from the
    # prototype-count comparison.
    is_sparse = signal_int_shape[1] == 1
    if not is_sparse and not equal_int_shape(signal_int_shape[1:2],
                                             proto_int_shape[0:1]):
        message = (
            "If the signal is not sparse, the number of prototypes must be equal in signals and "
            "protos. You provide: " + str(signal_int_shape[1]) + " != " +
            str(proto_int_shape[0]))
        raise ValueError(message)

    return True
|
||||
|
||||
|
||||
def _int_and_mixed_shape(tensor):
    """Return ``(shape, int_shape)`` for ``tensor``.

    ``shape`` is whatever ``mixed_shape`` reports; ``int_shape`` is the
    same tuple with every non-``int`` entry replaced by ``None``.
    """
    shape = mixed_shape(tensor)
    int_entries = []
    for dim in shape:
        int_entries.append(dim if isinstance(dim, int) else None)
    return shape, tuple(int_entries)
|
@@ -7,87 +7,101 @@ import torch
|
||||
INITIALIZERS = dict()
|
||||
|
||||
|
||||
def register_initializer(func):
|
||||
INITIALIZERS[func.__name__] = func
|
||||
return func
|
||||
def register_initializer(function):
|
||||
"""Add the initializer to the registry."""
|
||||
INITIALIZERS[function.__name__] = function
|
||||
return function
|
||||
|
||||
|
||||
def labels_from(distribution):
|
||||
def labels_from(distribution, one_hot=True):
|
||||
"""Takes a distribution tensor and returns a labels tensor."""
|
||||
nclasses = distribution.shape[0]
|
||||
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
|
||||
num_classes = distribution.shape[0]
|
||||
llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
|
||||
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
||||
labels = list(chain(*llist)) # flatten using itertools.chain
|
||||
return torch.tensor(labels, requires_grad=False)
|
||||
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
|
||||
plabels = torch.tensor(flat_llist, requires_grad=False)
|
||||
if one_hot:
|
||||
return torch.eye(num_classes)[plabels]
|
||||
return plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def ones(x_train, y_train, prototype_distribution):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.ones(nprotos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution)
|
||||
def ones(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.ones(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def zeros(x_train, y_train, prototype_distribution):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.zeros(nprotos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution)
|
||||
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.zeros(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def rand(x_train, y_train, prototype_distribution):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.rand(nprotos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution)
|
||||
def rand(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.rand(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def randn(x_train, y_train, prototype_distribution):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.randn(nprotos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution)
|
||||
def randn(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.randn(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_mean(x_train, y_train, prototype_distribution):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(nprotos, pdim)
|
||||
plabels = labels_from(prototype_distribution)
|
||||
for i, l in enumerate(plabels):
|
||||
xl = x_train[y_train == l]
|
||||
protos = torch.empty(num_protos, pdim)
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
for i, label in enumerate(plabels):
|
||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||
if one_hot:
|
||||
num_classes = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
xl = x_train[matcher]
|
||||
mean_xl = torch.mean(xl, dim=0)
|
||||
protos[i] = mean_xl
|
||||
plabels = labels_from(prototype_distribution, one_hot=one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_random(x_train,
                      y_train,
                      prototype_distribution,
                      one_hot=True,
                      epsilon=1e-7):
    """Initialize each prototype with a random training sample of its class.

    :param x_train: training samples, indexed as ``(sample, feature)``
    :param y_train: training targets; compared against the prototype
        labels (one-hot rows when ``one_hot`` is True)
    :param prototype_distribution: number of prototypes per class
    :param one_hot: whether labels are handled as one-hot vectors
    :param epsilon: constant added to every drawn sample
    :returns: ``(protos, plabels)``
    """
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            # A one-hot row matches only if all of its entries agree.
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        # random_(from, to) draws from [from, to - 1], so ``to`` has to be
        # the number of candidates. Passing ``n - 1`` would never select
        # the last sample and fails outright for a class with one sample.
        rand_index = torch.zeros(1).long().random_(0, xl.shape[0])
        random_xl = xl[rand_index]
        protos[i] = random_xl + epsilon
    return protos, plabels
|
||||
|
||||
|
||||
def get_initializer(funcname):
|
||||
"""Deserialize the initializer."""
|
||||
if callable(funcname):
|
||||
return funcname
|
||||
else:
|
||||
if funcname in INITIALIZERS:
|
||||
return INITIALIZERS.get(funcname)
|
||||
else:
|
||||
raise NameError(f'Initializer {funcname} was not found.')
|
||||
if funcname in INITIALIZERS:
|
||||
return INITIALIZERS.get(funcname)
|
||||
raise NameError(f"Initializer {funcname} was not found.")
|
||||
|
@@ -3,23 +3,92 @@
|
||||
import torch
|
||||
|
||||
|
||||
def glvq_loss(distances, target_labels, prototype_labels):
|
||||
"""GLVQ loss function with support for one-hot labels."""
|
||||
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
|
||||
if prototype_labels.ndim == 2:
|
||||
def _get_matcher(targets, labels):
    """Return a boolean tensor marking which labels equal which targets.

    Entry ``[i, j]`` is True when ``targets[i]`` equals ``labels[j]``. For
    2D (one-hot) labels a pair counts as equal only if every component
    agrees.
    """
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim == 2:
        # One-hot vectors: the element-wise comparison has a trailing class
        # axis; a full match means all ``num_classes`` entries are equal.
        num_classes = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    return matcher
|
||||
|
||||
|
||||
def _get_dp_dm(distances, targets, plabels, with_indices=False):
    """Return the d+ and d- values for a batch of distances.

    d+ is the smallest distance to a prototype whose label matches the
    target, d- the smallest distance to one whose label does not. Both
    keep a trailing dimension of size 1.

    :param with_indices: if True, return the full ``torch.min`` results
        (values and indices) instead of the values only
    """
    matcher = _get_matcher(targets, plabels)
    not_matcher = torch.bitwise_not(matcher)

    # Mask out the irrelevant prototypes with +inf so they can never win
    # the minimum.
    inf = torch.full_like(distances, fill_value=float("inf"))
    d_matching = torch.where(matcher, distances, inf)
    d_unmatching = torch.where(not_matcher, distances, inf)
    dp = torch.min(d_matching, dim=-1, keepdim=True)
    dm = torch.min(d_unmatching, dim=-1, keepdim=True)
    if with_indices:
        return dp, dm
    return dp.values, dm.values
|
||||
|
||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
||||
distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
|
||||
distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
|
||||
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
||||
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
||||
|
||||
mu = (dpluses - dminuses) / (dpluses + dminuses)
|
||||
def glvq_loss(distances, target_labels, prototype_labels):
    """GLVQ loss function with support for one-hot labels.

    Returns ``(d+ - d-) / (d+ + d-)`` per sample, with d+ and d- taken
    from ``_get_dp_dm``.
    """
    d_plus, d_minus = _get_dp_dm(distances, target_labels, prototype_labels)
    return (d_plus - d_minus) / (d_plus + d_minus)
|
||||
|
||||
|
||||
def lvq1_loss(distances, target_labels, prototype_labels):
    """LVQ1 loss function with support for one-hot labels.

    Yields d+ where d+ <= d-, and -d- where d+ > d-.

    See Section 4 [Sado&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    d_plus, d_minus = _get_dp_dm(distances, target_labels, prototype_labels)
    wrong_is_closer = d_plus > d_minus
    return torch.where(wrong_is_closer, -d_minus, d_plus)
|
||||
|
||||
|
||||
def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    Returns ``d+ - d-`` per sample.

    See Section 4 [Sado&Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    d_plus, d_minus = _get_dp_dm(distances, target_labels, prototype_labels)
    return d_plus - d_minus
|
||||
|
||||
|
||||
# Probabilistic
|
||||
def _get_class_probabilities(probabilities, targets, prototype_labels):
    """Split per-sample probability mass into whole, correct and wrong.

    Each target is mapped to the position of its value among the sorted
    unique ``prototype_labels``; that position selects the "correct"
    column of ``probabilities``.

    :returns: ``(whole, correct, wrong)`` where ``whole`` is the row sum
        and ``wrong = whole - correct``
    """
    # Map every distinct prototype label to its rank in sorted order.
    sorted_labels = prototype_labels.unique(sorted=True).tolist()
    column_of = {label: column for column, label in enumerate(sorted_labels)}

    target_columns = [column_of.get(target) for target in targets.tolist()]
    target_indices = torch.LongTensor(target_columns)

    rows = torch.arange(len(probabilities))
    whole = probabilities.sum(dim=1)
    correct = probabilities[rows, target_indices]
    wrong = whole - correct

    return whole, correct, wrong
|
||||
|
||||
|
||||
def nllr_loss(probabilities, targets, prototype_labels):
    """Compute the Negative Log-Likelihood Ratio loss.

    Returns ``-log(correct / wrong)`` per sample, using the split from
    ``_get_class_probabilities``.
    """
    split = _get_class_probabilities(probabilities, targets, prototype_labels)
    correct, wrong = split[1], split[2]
    ratio = correct / wrong
    return -1.0 * torch.log(ratio)
|
||||
|
||||
|
||||
def rslvq_loss(probabilities, targets, prototype_labels):
    """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss.

    Returns ``-log(correct / whole)`` per sample, using the split from
    ``_get_class_probabilities``.
    """
    split = _get_class_probabilities(probabilities, targets, prototype_labels)
    whole, correct = split[0], split[1]
    ratio = correct / whole
    return -1.0 * torch.log(ratio)
|
||||
|
35
prototorch/functions/normalization.py
Normal file
35
prototorch/functions/normalization.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def orthogonalization(tensors):
    r"""Orthogonalize a tensor via polar decomposition.

    Takes the SVD of ``tensors`` and returns the product of the left
    singular vectors with the transposed right singular vectors, reshaped
    back to the leading dimensions of the input.
    """
    left, _, right = torch.svd(tensors, compute_uv=True)
    left_shape = tuple(left.shape)
    right_shape = tuple(right.shape)

    # Collapse any leading batch dimensions into one: (num x N x M).
    batched_left = left.reshape((-1, left_shape[-2], left_shape[-1]))
    batched_right = right.reshape((-1, right_shape[-2], right_shape[-1]))

    product = torch.matmul(batched_left, batched_right.transpose(1, 2))

    return product.reshape(left_shape[:-1] + (right_shape[-2], ))
|
||||
|
||||
|
||||
def trace_normalization(tensors):
    r"""Normalize a matrix by its trace.

    The trace is bounded from below by ``1e-10`` before the division, so a
    zero (or negative) trace does not divide by zero.

    :param tensors: matrix accepted by ``torch.trace``
    :returns: ``tensors`` divided by its (bounded) trace, same shape
    """
    epsilon = 1e-10
    # Clamp with a Python float rather than a separate float64 tensor: the
    # bound then adopts the dtype and device of the trace instead of
    # forcing the result to float64 on the CPU.
    constant = torch.clamp(torch.trace(tensors), min=epsilon)
    return tensors / constant
|
80
prototorch/functions/pooling.py
Normal file
80
prototorch/functions/pooling.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""ProtoTorch pooling functions."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Pool the columns of `values` label by label with the strategy `fn`.

    For every unique label, the columns belonging to other labels are
    replaced by `fill_value` and `fn` is applied to the masked matrix; its
    result becomes one output column. Output columns follow the sorted
    unique labels. If `values` already has one column per unique label it
    is returned untouched.

    Labels may be 1-D, or 2-D (treated as one-hot rows).
    """
    unique_labels = torch.unique(labels, dim=0, sorted=True)
    num_classes = unique_labels.shape[0]
    if values.shape[1] == num_classes:
        # nothing to stratify: one column per class already
        return values

    one_hot = labels.ndim == 2
    values_t = values.T
    filler = torch.full_like(values_t, fill_value=fill_value)
    pooled = torch.zeros(num_classes, values.shape[0], device=labels.device)

    for row, class_label in enumerate(unique_labels):
        mask = labels.unsqueeze(dim=1) == class_label
        if one_hot:
            # a one-hot row matches only if all of its entries agree
            mask = mask.sum(dim=-1) == num_classes
        pooled[row] = fn(torch.where(mask, values_t, filler).T)

    pooled = pooled.T  # `batch_size` first
    if one_hot:
        # reverse the columns to fix the ordering of the classes
        return torch.flip(pooled, dims=(1, ))
    return pooled
|
||||
|
||||
|
||||
def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum.

    Delegates to `stratify_with`; columns of other groups are filled with
    0.0 before summing along dimension 1.
    """
    def _sum_rows(masked):
        return masked.sum(dim=1, keepdim=True).squeeze()

    return stratify_with(values, labels, fn=_sum_rows, fill_value=0.0)
|
||||
|
||||
|
||||
def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum.

    Delegates to `stratify_with`; columns of other groups are filled with
    +inf before taking the minimum along dimension 1.
    """
    def _min_rows(masked):
        return masked.min(dim=1, keepdim=True).values.squeeze()

    return stratify_with(values,
                         labels,
                         fn=_min_rows,
                         fill_value=float("inf"))
|
||||
|
||||
|
||||
def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum.

    Delegates to `stratify_with`; columns of other groups are filled with
    -inf before taking the maximum along dimension 1.
    """
    def _max_rows(masked):
        return masked.max(dim=1, keepdim=True).values.squeeze()

    return stratify_with(values,
                         labels,
                         fn=_max_rows,
                         fill_value=float("-inf"))
|
||||
|
||||
|
||||
def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product.

    Delegates to `stratify_with`; columns of other groups are filled with
    1.0 (the neutral element of multiplication) before the product is taken
    along dimension 1.
    """
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values
|
18
prototorch/functions/similarities.py
Normal file
18
prototorch/functions/similarities.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""ProtoTorch similarity functions."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
    """Compute the cosine similarity between :math:`x` and :math:`y`.

    Both `x` and `y` are expected to be 2-dimensional; the result holds one
    similarity per (row of x, row of y) pair. The matrix of norm products is
    clamped from below at the machine epsilon of its dtype, so all-zero rows
    do not divide by zero.
    """
    x_norms = (x**2).sum(dim=1).sqrt()
    y_norms = (y**2).sum(dim=1).sqrt()

    # Outer product of the two norm vectors.
    norm_products = torch.matmul(x_norms[:, None], y_norms[None, :])
    floor = torch.finfo(norm_products.dtype).eps
    norm_products = norm_products.clamp(min=floor)

    return torch.matmul(x, y.T) / norm_products
|
32
prototorch/functions/transforms.py
Normal file
32
prototorch/functions/transforms.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
|
||||
|
||||
# Functions
|
||||
def gaussian(distances, variance):
    """Return ``exp(-distances**2 / (2 * variance))`` element-wise."""
    squared = distances * distances
    return torch.exp(-squared / (2 * variance))
|
||||
|
||||
|
||||
def rank_scaled_gaussian(distances, lambd):
    """Return ``exp(-exp(-rank / lambd) * distances)`` element-wise.

    ``rank`` is the position of each distance within its row (dimension 1)
    when that row is sorted ascending; the smallest distance has rank 0.
    """
    sort_index = distances.argsort(dim=1)
    # argsort of an argsort yields the rank of every entry in its row
    ranks = sort_index.argsort(dim=1)
    scale = torch.exp(-ranks / lambd)
    return torch.exp(-scale * distances)
|
||||
|
||||
|
||||
# Modules
|
||||
class GaussianPrior(torch.nn.Module):
    """Module form of the `gaussian` function with a stored variance."""
    def __init__(self, variance):
        """Store `variance` for later use in `forward`."""
        super().__init__()
        self.variance = variance

    def forward(self, distances):
        """Return `gaussian(distances, self.variance)`."""
        return gaussian(distances, self.variance)
|
||||
|
||||
|
||||
class RankScaledGaussianPrior(torch.nn.Module):
    """Module form of the `rank_scaled_gaussian` function with a stored `lambd`."""
    def __init__(self, lambd):
        """Store `lambd` for later use in `forward`."""
        super().__init__()
        self.lambd = lambd

    def forward(self, distances):
        """Return `rank_scaled_gaussian(distances, self.lambd)`."""
        return rank_scaled_gaussian(distances, self.lambd)
|
@@ -0,0 +1,5 @@
|
||||
"""ProtoTorch modules."""
|
||||
|
||||
from .competitions import *
|
||||
from .pooling import *
|
||||
from .wrappers import LambdaLayer, LossLayer
|
||||
|
42
prototorch/modules/competitions.py
Normal file
42
prototorch/modules/competitions.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""ProtoTorch Competition Modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.competitions import knnc, wtac
|
||||
|
||||
|
||||
class WTAC(torch.nn.Module):
    """Winner-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, distances, labels):
        """Return `wtac(distances, labels)`; the layer holds no state."""
        return wtac(distances, labels)
|
||||
|
||||
|
||||
class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, probs, labels):
        """Return `wtac` applied to the negated `probs`."""
        # Negating the inputs reuses `wtac` for the opposite selection.
        return wtac(-1.0 * probs, labels)
|
||||
|
||||
|
||||
class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.

    """
    def __init__(self, k=1, **kwargs):
        """Store `k`; remaining keyword arguments go to `torch.nn.Module`."""
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        """Return `knnc(distances, labels, k=self.k)`."""
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        """Show the configured `k` in the module's printed representation."""
        return f"k: {self.k}"
|
@@ -7,15 +7,53 @@ from prototorch.functions.losses import glvq_loss
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
    """GLVQ Loss.

    Squashes the output of `glvq_loss` (shifted by `margin`) with the
    configured activation and sums the result over dimension 0.
    """
    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
        """Store the margin, resolve the squashing function and wrap `beta`.

        `squashing` is resolved through `get_activation`; `beta` is stored
        as a tensor and handed to the squashing function on every call.
        """
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = torch.tensor(beta)

    def forward(self, outputs, targets):
        """Return the summed loss for `outputs = (distances, plabels)`."""
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, prototype_labels=plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)
|
||||
|
||||
|
||||
class NeuralGasEnergy(torch.nn.Module):
    """Rank-weighted sum of distances.

    Every distance is weighted by `_nghood_fn` of its rank within its row
    (dimension 1) and the weighted distances are summed into one scalar.
    """
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    @staticmethod
    def _nghood_fn(rankings, lm):
        """Neighbourhood weight ``exp(-rank / lm)``."""
        return torch.exp(-rankings / lm)

    def extra_repr(self):
        return f"lambda: {self.lm}"

    def forward(self, d):
        """Return ``(cost, order)`` where `order` is the row-wise argsort of `d`."""
        order = d.argsort(dim=1)
        # argsort of the sort order gives the rank of each entry
        ranks = order.argsort(dim=1)
        weights = self._nghood_fn(ranks, self.lm)
        cost = (weights * d).sum()

        return cost, order
|
||||
|
||||
|
||||
class GrowingNeuralGasEnergy(NeuralGasEnergy):
    """Neural gas energy whose weights come from a topology object.

    Only the neighbourhood function is replaced; `forward` is inherited.
    """
    def __init__(self, topology_layer, **kwargs):
        """Store `topology_layer`; `kwargs` go to `NeuralGasEnergy.__init__`."""
        super().__init__(**kwargs)
        # NOTE(review): the inherited `forward` calls
        # `self._nghood_fn(ranks, self.lm)`, so `self.lm` (not
        # `self.topology_layer`) arrives as `topology` below - confirm intent.
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        """Return a weight matrix shaped like `rankings`.

        Weights start at zero, the entry selected by `winner` in each row is
        set to 1.0 and the entries indexed by
        `topology.get_neighbours(winner)` are set to 0.1.
        """
        # NOTE(review): this takes the first column of the rank matrix;
        # presumably meant as the index of the winning unit - verify.
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
|
||||
|
170
prototorch/modules/models.py
Normal file
170
prototorch/modules/models.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.functions.distances import euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
    r""" Generalized Tangent Learning Vector Quantization

    Parameters
    ----------
    num_classes: int
        Number of classes of the given classification problem.

    subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
        Subspace data for the point approximation, required

    prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
        prototype data for initialization of the prototypes used in GTLVQ.

    subspace_size: int (default=256,optional)
        Subspace dimension of the Projectors. Currently only supported
        with tangent_projection_type=global.

    tangent_projection_type: string
        Specifies the tangent projection type
        options:    local
                    local_proj
                    global
        local: computes the tangent distances without emphasizing projected
        data. Only distances are available
        local_proj: computes tangent distances and returns the projected data
        for further use. Be careful: data is repeated by number of prototypes
        global: Number of subspaces is set to one and every prototype
        uses the same.

    prototypes_per_class: int (default=2,optional)
        Number of prototypes per class

    feature_dim: int (default=256)
        Dimensionality of the feature space specified as integer.
        Prototype dimension.

    Notes
    -----
    The GTLVQ [1] is a prototype-based classification learning model. The
    GTLVQ uses the Tangent-Distances for a local point approximation
    of an assumed data manifold via prototypical representations.

    The GTLVQ requires subspace projectors for transforming the data
    and prototypes into the affine subspace. Every prototype is
    equipped with a specific subspace and represents a point
    approximation of the assumed manifold.

    In practice prototypes and data are projected on this manifold
    and the pairwise euclidean distance is computed.

    References
    ----------
    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
        in classification based on manifold models and its relation
        to tangent metric learning. In: 2017 International Joint
        Conference on Neural Networks (IJCNN).
        Bd. 2017-May : IEEE, 2017, S. 1756–1765
    """
    def __init__(
        self,
        num_classes,
        subspace_data=None,
        prototype_data=None,
        subspace_size=256,
        tangent_projection_type="local",
        prototypes_per_class=2,
        feature_dim=256,
    ):
        super().__init__()

        # Fail before any component is built; the subspaces cannot be
        # initialized without data.
        if subspace_data is None:
            raise ValueError("Init Data must be specified!")

        self.num_protos = num_classes * prototypes_per_class
        self.num_protos_class = prototypes_per_class
        self.subspace_size = feature_dim if subspace_size is None else subspace_size
        self.feature_dim = feature_dim
        self.num_classes = num_classes

        cls_initializer = StratifiedMeanInitializer(prototype_data)
        cls_distribution = {
            "num_classes": num_classes,
            "prototypes_per_class": prototypes_per_class,
        }

        self.cls = LabeledComponents(cls_distribution, cls_initializer)

        self.tpt = tangent_projection_type
        with torch.no_grad():
            if self.tpt == "local":
                self.init_local_subspace(subspace_data, subspace_size,
                                         self.num_protos)
            elif self.tpt == "global":
                self.init_gobal_subspace(subspace_data, subspace_size)
            else:
                # Any other projection type runs without subspaces.
                self.subspaces = None

    def forward(self, x):
        """Return the dissimilarities between `x` and the prototypes.

        "local" and "global" use the respective tangent distance; every
        other projection type falls back to the cosine similarity between
        `x` and the prototypes.
        """
        if self.tpt == "local":
            dis = self.local_tangent_distances(x)
        elif self.tpt == "global":
            dis = self.global_tangent_distances(x)
        else:
            protos = self.cls.components
            dis = (x @ protos.T) / (
                torch.norm(x, dim=1, keepdim=True) @ torch.norm(
                    protos, dim=1, keepdim=True).T)
        return dis

    def init_gobal_subspace(self, data, num_subspaces):
        """Initialize one shared, trainable subspace from the SVD of `data`.

        The first `num_subspaces` columns of ``(I - v v^T)^T`` are kept,
        where `v` holds the right singular vectors of `data`.
        """
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def init_local_subspace(self, data, num_subspaces, num_protos):
        """Initialize one trainable subspace per prototype.

        `data` is centered along dimension 0, its first `num_subspaces`
        right singular vectors are taken and repeated `num_protos` times
        along a new leading dimension.
        """
        data = data - torch.mean(data, dim=0)
        _, _, v = torch.svd(data, some=False)
        v = v[:, :num_subspaces]
        subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def global_tangent_distances(self, x):
        """Project `x` and the prototypes with the shared subspace and
        return their `euclidean_distance_matrix`."""
        # Tangent Projection
        x, projected_prototypes = (
            x @ self.subspaces,
            self.cls.components @ self.subspaces,
        )
        # Euclidean Distance
        return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_distances(self, x):
        """Return the tangent distance of every sample to every prototype.

        Each difference ``x - prototype`` is multiplied with that
        prototype's projector ``I - W W^T`` (``W`` being its subspace) and
        the 2-norm of the result is taken. The output has one row per
        sample and one column per prototype.
        """
        # Tangent Distance
        x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
                                  x.size(-1))
        protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
                                                   self.cls.num_components,
                                                   x.size(-1))
        projectors = torch.eye(
            self.subspaces.shape[-2], device=x.device) - torch.bmm(
                self.subspaces, self.subspaces.permute([0, 2, 1]))
        diff = (x - protos)
        # prototypes first, so `bmm` pairs each prototype with its projector
        diff = diff.permute([1, 0, 2])
        diff = torch.bmm(diff, projectors)
        diff = torch.norm(diff, 2, dim=-1).T
        return diff

    def get_parameters(self):
        """Return two optimizer parameter groups: components and subspaces."""
        return {
            "params": self.cls.components,
        }, {
            "params": self.subspaces
        }

    def orthogonalize_subspace(self):
        """Re-orthogonalize the subspaces in place (no-op without subspaces)."""
        if self.subspaces is not None:
            with torch.no_grad():
                # NOTE(review): for non-global types `orthogonal_` fills the
                # parameter with a new (semi-)orthogonal matrix rather than
                # orthogonalizing its current values - confirm this is
                # intended.
                ortho_subpsaces = (orthogonalization(self.subspaces)
                                   if self.tpt == "global" else
                                   torch.nn.init.orthogonal_(self.subspaces))
                self.subspaces.copy_(ortho_subpsaces)
|
32
prototorch/modules/pooling.py
Normal file
32
prototorch/modules/pooling.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""ProtoTorch Pooling Modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.pooling import (stratified_max_pooling,
|
||||
stratified_min_pooling,
|
||||
stratified_prod_pooling,
|
||||
stratified_sum_pooling)
|
||||
|
||||
|
||||
class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        """Return `stratified_sum_pooling(values, labels)`."""
        return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        """Return `stratified_prod_pooling(values, labels)`."""
        return stratified_prod_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        """Return `stratified_min_pooling(values, labels)`."""
        return stratified_min_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        """Return `stratified_max_pooling(values, labels)`."""
        return stratified_max_pooling(values, labels)
|
@@ -1,57 +0,0 @@
|
||||
"""ProtoTorch prototype modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
class AddPrototypes1D(torch.nn.Module):
    """Module holding a trainable prototype tensor and its labels.

    Prototypes are produced by the initializer named by
    `prototype_initializer` (resolved through `get_initializer`), either
    from the given `data = (x_train, y_train)` or, when `data` is None,
    from random data built from the `input_dim` keyword argument.

    Raises:
        NameError: if `data` is None and `input_dim` is missing, or if
            neither `prototype_distribution` nor `nclasses` is given.
    """
    def __init__(self,
                 prototypes_per_class=1,
                 prototype_distribution=None,
                 prototype_initializer='ones',
                 data=None,
                 **kwargs):

        if data is None:
            if 'input_dim' not in kwargs:
                raise NameError('`input_dim` required if '
                                'no `data` is provided.')
            if prototype_distribution is not None:
                # NOTE(review): this is the total number of prototypes, not
                # the number of entries in the distribution - confirm that
                # the initializer expects that many distinct labels.
                nclasses = sum(prototype_distribution)
            else:
                if 'nclasses' not in kwargs:
                    raise NameError('`prototype_distribution` required if '
                                    'both `data` and `nclasses` are not '
                                    'provided.')
                nclasses = kwargs.pop('nclasses')
            input_dim = kwargs.pop('input_dim')
            # input_shape = (input_dim, )
            # Stand-in data: one random sample per class.
            x_train = torch.rand(nclasses, input_dim)
            y_train = torch.arange(nclasses)

        else:
            x_train, y_train = data
            x_train = torch.as_tensor(x_train)
            y_train = torch.as_tensor(y_train)

        # `input_dim` / `nclasses` were popped above, so only the remaining
        # keyword arguments reach `torch.nn.Module.__init__`.
        super().__init__(**kwargs)
        self.prototypes_per_class = prototypes_per_class
        with torch.no_grad():
            if not prototype_distribution:
                # Same number of prototypes for every label found in y_train.
                num_classes = torch.unique(y_train).shape[0]
                self.prototype_distribution = torch.tensor(
                    [self.prototypes_per_class] * num_classes)
            else:
                self.prototype_distribution = torch.tensor(
                    prototype_distribution)
            self.prototype_initializer = get_initializer(prototype_initializer)
            prototypes, prototype_labels = self.prototype_initializer(
                x_train,
                y_train,
                prototype_distribution=self.prototype_distribution)
        self.prototypes = torch.nn.Parameter(prototypes)
        self.prototype_labels = prototype_labels

    def forward(self):
        """Return the prototype parameter and the prototype labels."""
        return self.prototypes, self.prototype_labels
|
36
prototorch/modules/wrappers.py
Normal file
36
prototorch/modules/wrappers.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""ProtoTorch Wrappers."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LambdaLayer(torch.nn.Module):
    """Wrap an arbitrary callable as a ``torch.nn.Module``.

    Args:
        fn: Callable invoked by `forward` with all of its arguments.
        name: Optional display name used by `extra_repr`. Defaults to
            ``fn.__name__``, or to the type name of `fn` when the callable
            has no ``__name__`` (e.g. ``functools.partial`` objects or
            module instances).
    """
    def __init__(self, fn, name=None):
        super().__init__()
        self.fn = fn
        # lambda fns get <lambda>; callables without `__name__` fall back
        # to their type name instead of raising AttributeError
        self.name = name or getattr(fn, "__name__", type(fn).__name__)

    def forward(self, *args, **kwargs):
        """Call the wrapped function and return its result."""
        return self.fn(*args, **kwargs)

    def extra_repr(self):
        """Show the layer name in the module's printed representation."""
        return self.name
|
||||
|
||||
|
||||
class LossLayer(torch.nn.modules.loss._Loss):
    """Wrap an arbitrary callable as a torch loss module.

    Args:
        fn: Callable invoked by `forward` with all of its arguments.
        name: Optional display name used by `extra_repr`. Defaults to
            ``fn.__name__``, or to the type name of `fn` when the callable
            has no ``__name__`` (e.g. ``functools.partial`` objects).
        size_average, reduce, reduction: Passed through to the ``_Loss``
            base class; `forward` itself does not apply any reduction.
    """
    def __init__(self,
                 fn,
                 name=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = "mean") -> None:
        super().__init__(size_average=size_average,
                         reduce=reduce,
                         reduction=reduction)
        self.fn = fn
        # lambda fns get <lambda>; callables without `__name__` fall back
        # to their type name instead of raising AttributeError
        self.name = name or getattr(fn, "__name__", type(fn).__name__)

    def forward(self, *args, **kwargs):
        """Call the wrapped function and return its result."""
        return self.fn(*args, **kwargs)

    def extra_repr(self):
        """Show the layer name in the module's printed representation."""
        return self.name
|
0
prototorch/utils/__init__.py
Normal file
0
prototorch/utils/__init__.py
Normal file
46
prototorch/utils/celluloid.py
Normal file
46
prototorch/utils/celluloid.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
from matplotlib.animation import ArtistAnimation
|
||||
from matplotlib.artist import Artist
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
__version__ = "0.2.0"
|
||||
|
||||
|
||||
class Camera:
    """Record the artists drawn on a figure frame by frame.

    Each call to `snap` stores the artists added since the previous call;
    `animate` turns the stored frames into an ``ArtistAnimation``.
    """
    def __init__(self, figure: Figure) -> None:
        """Create camera from matplotlib figure."""
        self._figure = figure
        # Per artist kind and per axis index: how many artists were already
        # captured by earlier snapshots.
        tracked = ["collections", "patches", "lines", "texts", "artists", "images"]
        self._offsets: Dict[str, Dict[int, int]] = {}
        for kind in tracked:
            self._offsets[kind] = defaultdict(int)
        self._photos: List[List[Artist]] = []

    def snap(self) -> List[Artist]:
        """Capture current state of the figure."""
        frame_artists: List[Artist] = []
        for i, axis in enumerate(self._figure.axes):
            legend = axis.legend_
            if legend is not None:
                axis.add_artist(legend)
            for name, seen in self._offsets.items():
                start = seen[i]
                fresh = getattr(axis, name)[start:]
                frame_artists += fresh
                seen[i] = start + len(fresh)
        self._photos.append(frame_artists)
        return frame_artists

    def animate(self, *args, **kwargs) -> ArtistAnimation:
        """Animate the snapshots taken.

        Uses matplotlib.animation.ArtistAnimation; extra arguments are
        passed on to it.

        Returns
        -------
        ArtistAnimation
        """
        return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
|
78
prototorch/utils/colors.py
Normal file
78
prototorch/utils/colors.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""ProtoFlow color utilities."""
|
||||
|
||||
import matplotlib.lines as mlines
|
||||
from matplotlib import cm
|
||||
from matplotlib.colors import Normalize, to_hex, to_rgb
|
||||
|
||||
|
||||
def color_scheme(n,
                 cmap="viridis",
                 form="hex",
                 tikz=False,
                 zero_indexed=False):
    """Return *n* colors sampled from a matplotlib colormap.

    Arguments:
        n (int): number of colors to return

    Keyword Arguments:
        cmap (str): Name of a matplotlib colormap.
        form (str): Color format; "rgb" returns RGB tuples, anything else
            returns hex strings.
        tikz (bool): Additionally print one TikZ ``\\definecolor`` command
            per color.
        zero_indexed (bool): Key the result by 0..n-1 instead of 1..n.

    Returns:
        (dict): Mapping from color index to color.
    """
    colormap = cm.get_cmap(cmap)
    normalize = Normalize(vmin=1, vmax=n)
    shift = 1 if zero_indexed else 0

    hex_map = {}
    rgb_map = {}
    for cl in range(1, n + 1):
        rgba = colormap(normalize(cl))
        hex_map[cl - shift] = to_hex(rgba)
        rgb_map[cl - shift] = to_rgb(rgba)

    if tikz:
        for k, v in rgb_map.items():
            print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")

    return rgb_map if form == "rgb" else hex_map
|
||||
|
||||
|
||||
def get_legend_handles(labels, marker="dots", zero_indexed=False):
    """Return matplotlib legend handles and colors.

    One ``Line2D`` handle is built per label, colored from the "viridis"
    scheme. With ``marker="dots"`` the handle is a filled circle; any other
    value gives a plain colored line.

    Returns:
        (handles, colors): the list of handles and the color mapping
        returned by `color_scheme`.
    """
    colors = color_scheme(len(labels),
                          cmap="viridis",
                          form="hex",
                          zero_indexed=zero_indexed)
    use_dots = marker == "dots"

    handles = []
    for label, color in zip(labels, colors.values()):
        if use_dots:
            style = dict(color="white",
                         markerfacecolor=color,
                         marker="o",
                         markersize=10,
                         markeredgecolor="k")
        else:
            style = dict(color=color, marker="", markersize=15)
        handles.append(mlines.Line2D([], [], label=label, **style))
    return handles, colors
|
128
setup.py
128
setup.py
@@ -1,49 +1,89 @@
|
||||
"""Install ProtoTorch."""
|
||||
"""
|
||||
|
||||
from setuptools import setup
|
||||
from setuptools import find_packages
|
||||
######
|
||||
# # ##### #### ##### #### ##### #### ##### #### # #
|
||||
# # # # # # # # # # # # # # # # # #
|
||||
###### # # # # # # # # # # # # # ######
|
||||
# ##### # # # # # # # # ##### # # #
|
||||
# # # # # # # # # # # # # # # # #
|
||||
# # # #### # #### # #### # # #### # #
|
||||
|
||||
PROJECT_URL = 'https://github.com/si-cim/prototorch'
|
||||
DOWNLOAD_URL = 'https://github.com/si-cim/prototorch.git'
|
||||
ProtoTorch Core Package
|
||||
"""
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
with open('README.md', 'r') as fh:
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setup(name='prototorch',
|
||||
version='0.1.0-dev0',
|
||||
description='Highly extensible, GPU-supported '
|
||||
'Learning Vector Quantization (LVQ) toolbox '
|
||||
'built using PyTorch and its nn API.',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
author='Jensun Ravichandran',
|
||||
author_email='jjensun@gmail.com',
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license='MIT',
|
||||
install_requires=[
|
||||
'torch>=1.3.1',
|
||||
'torchvision>=0.5.0',
|
||||
'numpy>=1.9.1',
|
||||
],
|
||||
extras_require={
|
||||
'examples': [
|
||||
'sklearn',
|
||||
'matplotlib',
|
||||
],
|
||||
'tests': ['pytest'],
|
||||
},
|
||||
classifiers=[
|
||||
'Development Status :: 2 - Pre-Alpha', 'Environment :: Console',
|
||||
'Intended Audience :: Developers', 'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Operating System :: OS Independent',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules'
|
||||
],
|
||||
packages=find_packages())
|
||||
INSTALL_REQUIRES = [
|
||||
"torch>=1.3.1",
|
||||
"torchvision>=0.5.1",
|
||||
"numpy>=1.9.1",
|
||||
"sklearn",
|
||||
]
|
||||
DATASETS = [
|
||||
"requests",
|
||||
"tqdm",
|
||||
]
|
||||
DEV = [
|
||||
"bumpversion",
|
||||
"pre-commit",
|
||||
]
|
||||
DOCS = [
|
||||
"recommonmark",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib-katex",
|
||||
"sphinx-autodoc-typehints",
|
||||
]
|
||||
EXAMPLES = [
|
||||
"matplotlib",
|
||||
"torchinfo",
|
||||
]
|
||||
TESTS = ["codecov", "pytest"]
|
||||
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||
|
||||
setup(
|
||||
name="prototorch",
|
||||
version="0.5.1",
|
||||
description="Highly extensible, GPU-supported "
|
||||
"Learning Vector Quantization (LVQ) toolbox "
|
||||
"built using PyTorch and its nn API.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
author="Jensun Ravichandran",
|
||||
author_email="jjensun@gmail.com",
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license="MIT",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"docs": DOCS,
|
||||
"datasets": DATASETS,
|
||||
"examples": EXAMPLES,
|
||||
"tests": TESTS,
|
||||
"all": ALL,
|
||||
},
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
],
|
||||
packages=find_packages(),
|
||||
zip_safe=False,
|
||||
)
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
26
tests/test_components.py
Normal file
26
tests/test_components.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""ProtoTorch components test suite."""
|
||||
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def test_labcomps_zeros_init():
    """Components created with the `Zeros` initializer are all zero."""
    protos = torch.zeros(3, 2)
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=pt.components.Zeros(2),
    )
    # `.all()`, not `.any()`: every entry has to match, not just one of them.
    assert (c.components == protos).all()
|
||||
|
||||
|
||||
def test_labcomps_warmstart():
    """Components and labels passed as `initialized_components` are kept."""
    protos = torch.randn(3, 2)
    plabels = torch.tensor([1, 2, 3])
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=None,
        initialized_components=[protos, plabels],
    )
    # `.all()`, not `.any()`: every entry has to match, not just one of them.
    assert (c.components == protos).all()
    assert (c.component_labels == plabels).all()
|
95
tests/test_datasets.py
Normal file
95
tests/test_datasets.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""ProtoTorch datasets test suite."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.datasets import abstract, tecator
|
||||
|
||||
|
||||
class TestAbstract(unittest.TestCase):
    """Checks that `abstract.Dataset` leaves item access and length abstract."""
    def test_getitem(self):
        # Indexing the base class must raise NotImplementedError.
        with self.assertRaises(NotImplementedError):
            abstract.Dataset("./artifacts")[0]

    def test_len(self):
        # `len()` on the base class must raise NotImplementedError.
        with self.assertRaises(NotImplementedError):
            len(abstract.Dataset("./artifacts"))
|
||||
|
||||
|
||||
class TestProtoDataset(unittest.TestCase):
    """Checks that `abstract.ProtoDataset` leaves its hooks abstract."""
    def test_getitem(self):
        # Indexing the base class must raise NotImplementedError.
        with self.assertRaises(NotImplementedError):
            abstract.ProtoDataset("./artifacts")[0]

    def test_download(self):
        # `download()` on the base class must raise NotImplementedError.
        with self.assertRaises(NotImplementedError):
            abstract.ProtoDataset("./artifacts").download()
|
||||
|
||||
|
||||
class TestTecator(unittest.TestCase):
    """Tests for the Tecator dataset wrapper.

    Every test starts from a clean ``./artifacts/Tecator`` directory.
    """
    def setUp(self):
        self.artifacts_dir = "./artifacts/Tecator"
        self._remove_artifacts()

    def tearDown(self):
        pass

    def _remove_artifacts(self):
        """Delete the dataset directory if it exists."""
        if os.path.exists(self.artifacts_dir):
            shutil.rmtree(self.artifacts_dir)

    def _rootdir(self):
        """Return the parent directory of the artifacts directory."""
        return self.artifacts_dir.rpartition("/")[0]

    def test_download_false(self):
        rootdir = self._rootdir()
        self._remove_artifacts()
        with self.assertRaises(RuntimeError):
            _ = tecator.Tecator(rootdir, download=False)

    def test_download_caching(self):
        rootdir = self._rootdir()
        _ = tecator.Tecator(rootdir, download=True, verbose=False)
        _ = tecator.Tecator(rootdir, download=False, verbose=False)

    def test_repr(self):
        train = tecator.Tecator(self._rootdir(), download=True, verbose=True)
        self.assertIn("Split: Train", repr(train))

    def test_download_train(self):
        rootdir = self._rootdir()
        train = tecator.Tecator(root=rootdir,
                                train=True,
                                download=True,
                                verbose=False)
        train = tecator.Tecator(root=rootdir, download=True, verbose=False)
        x_train, y_train = train.data, train.targets
        self.assertEqual(x_train.shape[0], 144)
        self.assertEqual(y_train.shape[0], 144)
        self.assertEqual(x_train.shape[1], 100)

    def test_download_test(self):
        test = tecator.Tecator(root=self._rootdir(),
                               train=False,
                               verbose=False)
        x_test, y_test = test.data, test.targets
        self.assertEqual(x_test.shape[0], 71)
        self.assertEqual(y_test.shape[0], 71)
        self.assertEqual(x_test.shape[1], 100)

    def test_class_to_idx(self):
        test = tecator.Tecator(root=self._rootdir(),
                               train=False,
                               verbose=False)
        _ = test.class_to_idx

    def test_getitem(self):
        test = tecator.Tecator(root=self._rootdir(),
                               train=False,
                               verbose=False)
        x, y = test[0]
        self.assertEqual(x.shape[0], 100)
        self.assertIsInstance(y, int)

    def test_loadable_with_dataloader(self):
        test = tecator.Tecator(root=self._rootdir(),
                               train=False,
                               verbose=False)
        _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
@@ -6,7 +6,208 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions import (activations, competitions, distances,
|
||||
initializers)
|
||||
initializers, losses, pooling)
|
||||
|
||||
|
||||
class TestActivations(unittest.TestCase):
    """Tests for the activation registry and the activation functions."""
    def setUp(self):
        self.flist = ["identity", "sigmoid_beta", "swish_beta"]
        self.x = torch.randn(1024, 1)

    def tearDown(self):
        del self.x

    def _assert_almost_equal(self, actual, desired):
        """Compare two arrays to 5 decimals via numpy's testing helper."""
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_registry(self):
        self.assertIsNotNone(activations.ACTIVATIONS)

    def test_funcname_deserialization(self):
        for funcname in self.flist:
            resolved = activations.get_activation(funcname)
            self.assertTrue(callable(resolved))

    # def test_torch_script(self):
    #     for funcname in self.flist:
    #         f = activations.get_activation(funcname)
    #         self.assertIsInstance(f, torch.jit.ScriptFunction)

    def test_callable_deserialization(self):
        def dummy(x, **kwargs):
            return x

        for candidate in [dummy, lambda x: x]:
            resolved = activations.get_activation(candidate)
            self.assertTrue(callable(resolved))
            self.assertEqual(1, resolved(1))

    def test_unknown_deserialization(self):
        for funcname in ["blubb", "foobar"]:
            with self.assertRaises(NameError):
                _ = activations.get_activation(funcname)

    def test_identity(self):
        self._assert_almost_equal(activations.identity(self.x), self.x)

    def test_sigmoid_beta1(self):
        self._assert_almost_equal(
            activations.sigmoid_beta(self.x, beta=1.0),
            torch.sigmoid(self.x),
        )

    def test_swish_beta1(self):
        self._assert_almost_equal(
            activations.swish_beta(self.x, beta=1.0),
            self.x * torch.sigmoid(self.x),
        )
|
||||
|
||||
|
||||
class TestCompetitions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_wtac(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_unequal_dist(self):
|
||||
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||
labels = torch.tensor([0, 1, 1])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([0, 1])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_one_hot(self):
|
||||
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([[0, 1], [1, 0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.knnc(d, labels, k=1)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestPooling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_stratified_min(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_trivial(self):
|
||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 1, 2])
|
||||
actual = pooling.stratified_min_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pooling.stratified_max_pooling(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||
actual = pooling.stratified_max_pooling(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.LongTensor([0, 0, 1, 2])
|
||||
actual = pooling.stratified_sum_pooling(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = pooling.stratified_sum_pooling(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_prod(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = pooling.stratified_prod_pooling(d, labels)
|
||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestDistances(unittest.TestCase):
|
||||
@@ -53,12 +254,12 @@ class TestDistances(unittest.TestCase):
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
@@ -113,14 +314,14 @@ class TestDistances(unittest.TestCase):
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_lpnorm_pinf(self):
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=float('inf'))
|
||||
actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=float('inf'),
|
||||
p=float("inf"),
|
||||
keepdim=False,
|
||||
)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
@@ -134,12 +335,12 @@ class TestDistances(unittest.TestCase):
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
@@ -152,12 +353,12 @@ class TestDistances(unittest.TestCase):
|
||||
desired = torch.empty(self.nx, self.ny)
|
||||
for i in range(self.nx):
|
||||
for j in range(self.ny):
|
||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||
self.x[i].reshape(1, -1),
|
||||
self.y[j].reshape(1, -1),
|
||||
p=2,
|
||||
keepdim=False,
|
||||
)**2
|
||||
)**2)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=2)
|
||||
@@ -167,103 +368,16 @@ class TestDistances(unittest.TestCase):
|
||||
del self.x, self.y
|
||||
|
||||
|
||||
class TestActivations(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.x = torch.randn(1024, 1)
|
||||
|
||||
def test_registry(self):
|
||||
self.assertIsNotNone(activations.ACTIVATIONS)
|
||||
|
||||
def test_funcname_deserialization(self):
|
||||
flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
||||
for funcname in flist:
|
||||
f = activations.get_activation(funcname)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
|
||||
def test_callable_deserialization(self):
|
||||
def dummy(x, **kwargs):
|
||||
return x
|
||||
|
||||
for f in [dummy, lambda x: x]:
|
||||
f = activations.get_activation(f)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
self.assertEqual(1, f(1))
|
||||
|
||||
def test_unknown_deserialization(self):
|
||||
for funcname in ['blubb', 'foobar']:
|
||||
with self.assertRaises(NameError):
|
||||
_ = activations.get_activation(funcname)
|
||||
|
||||
def test_identity(self):
|
||||
actual = activations.identity(self.x)
|
||||
desired = self.x
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_sigmoid_beta1(self):
|
||||
actual = activations.sigmoid_beta(self.x, beta=1)
|
||||
desired = torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_swish_beta1(self):
|
||||
actual = activations.swish_beta(self.x, beta=1)
|
||||
desired = self.x * torch.sigmoid(self.x)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x
|
||||
|
||||
|
||||
class TestCompetitions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_wtac(self):
|
||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_wtac_one_hot(self):
|
||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
actual = competitions.wtac(d, labels)
|
||||
desired = torch.tensor([[0, 1], [1, 0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
actual = competitions.knnc(d, labels, k=1)
|
||||
desired = torch.tensor([2, 0])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestInitializers(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.flist = [
|
||||
"zeros",
|
||||
"ones",
|
||||
"rand",
|
||||
"randn",
|
||||
"stratified_mean",
|
||||
"stratified_random",
|
||||
]
|
||||
self.x = torch.tensor(
|
||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||
dtype=torch.float32)
|
||||
@@ -274,11 +388,7 @@ class TestInitializers(unittest.TestCase):
|
||||
self.assertIsNotNone(initializers.INITIALIZERS)
|
||||
|
||||
def test_funcname_deserialization(self):
|
||||
flist = [
|
||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
||||
'stratified_random'
|
||||
]
|
||||
for funcname in flist:
|
||||
for funcname in self.flist:
|
||||
f = initializers.get_initializer(funcname)
|
||||
iscallable = callable(f)
|
||||
self.assertTrue(iscallable)
|
||||
@@ -294,7 +404,7 @@ class TestInitializers(unittest.TestCase):
|
||||
self.assertEqual(1, f(1))
|
||||
|
||||
def test_unknown_deserialization(self):
|
||||
for funcname in ['blubb', 'foobar']:
|
||||
for funcname in ["blubb", "foobar"]:
|
||||
with self.assertRaises(NameError):
|
||||
_ = initializers.get_initializer(funcname)
|
||||
|
||||
@@ -336,8 +446,8 @@ class TestInitializers(unittest.TestCase):
|
||||
|
||||
def test_stratified_mean_equal1(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
|
||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
@@ -345,8 +455,9 @@ class TestInitializers(unittest.TestCase):
|
||||
|
||||
def test_stratified_random_equal1(self):
|
||||
pdist = torch.tensor([1, 1])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
||||
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
@@ -354,9 +465,20 @@ class TestInitializers(unittest.TestCase):
|
||||
|
||||
def test_stratified_mean_equal2(self):
|
||||
pdist = torch.tensor([2, 2])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
|
||||
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
|
||||
[1., 1., 1.]])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_equal2(self):
|
||||
pdist = torch.tensor([2, 2])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
@@ -364,9 +486,9 @@ class TestInitializers(unittest.TestCase):
|
||||
|
||||
def test_stratified_mean_unequal(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
|
||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
||||
[1., 1., 1.]])
|
||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
@@ -374,14 +496,86 @@ class TestInitializers(unittest.TestCase):
|
||||
|
||||
def test_stratified_random_unequal(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
|
||||
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.],
|
||||
[0., 0., 0.]])
|
||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||
False)
|
||||
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_mean_unequal_one_hot(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
y = torch.eye(2)[self.y]
|
||||
desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||
desired1,
|
||||
decimal=5)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual2,
|
||||
desired2,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_random_unequal_one_hot(self):
|
||||
pdist = torch.tensor([1, 3])
|
||||
y = torch.eye(2)[self.y]
|
||||
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
||||
desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||
desired1,
|
||||
decimal=5)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual2,
|
||||
desired2,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x, self.y, self.gen
|
||||
_ = torch.seed()
|
||||
|
||||
|
||||
class TestLosses(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_glvq_loss_int_labels(self):
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([0, 1])
|
||||
targets = torch.ones(100)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def test_glvq_loss_one_hot_labels(self):
|
||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def test_glvq_loss_one_hot_unequal(self):
|
||||
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
|
||||
d = torch.stack(dlist, dim=1)
|
||||
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
|
||||
wl = torch.tensor([1, 0])
|
||||
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||
batch_loss = losses.glvq_loss(distances=d,
|
||||
target_labels=targets,
|
||||
prototype_labels=labels)
|
||||
loss_value = torch.sum(batch_loss, dim=0)
|
||||
self.assertEqual(loss_value, -100)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
@@ -1,129 +0,0 @@
|
||||
"""ProtoTorch modules test suite."""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.modules import prototypes, losses
|
||||
|
||||
|
||||
class TestPrototypes(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.x = torch.tensor(
|
||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||
dtype=torch.float32)
|
||||
self.y = torch.tensor([0, 0, 1, 1])
|
||||
self.gen = torch.manual_seed(42)
|
||||
|
||||
def test_addprototypes1d_init_without_input_dim(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.AddPrototypes1D(nclasses=1)
|
||||
|
||||
def test_addprototypes1d_init_without_nclasses(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.AddPrototypes1D(input_dim=1)
|
||||
|
||||
def test_addprototypes1d_init_without_pdist(self):
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=6,
|
||||
nclasses=2,
|
||||
prototypes_per_class=4,
|
||||
prototype_initializer='ones')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.ones(8, 6)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_init_without_data(self):
|
||||
pdist = [2, 2]
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=3,
|
||||
prototype_distribution=pdist,
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(4, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
# def test_addprototypes1d_init_torch_pdist(self):
|
||||
# pdist = torch.tensor([2, 2])
|
||||
# p1 = prototypes.AddPrototypes1D(input_dim=3,
|
||||
# prototype_distribution=pdist,
|
||||
# prototype_initializer='zeros')
|
||||
# protos = p1.prototypes
|
||||
# actual = protos.detach().numpy()
|
||||
# desired = torch.zeros(4, 3)
|
||||
# mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
# desired,
|
||||
# decimal=5)
|
||||
# self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_init_with_ppc(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(4, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_init_with_pdist(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
||||
prototype_distribution=[6, 9],
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(15, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_func_initializer(self):
|
||||
def my_initializer(*args, **kwargs):
|
||||
return torch.full((2, 99), 99), torch.tensor([0, 1])
|
||||
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=99,
|
||||
nclasses=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer=my_initializer)
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = 99 * torch.ones(2, 99)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_forward(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
|
||||
protos, _ = p1()
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.ones(2, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def tearDown(self):
|
||||
del self.x, self.y, self.gen
|
||||
_ = torch.seed()
|
||||
|
||||
|
||||
class TestLosses(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_glvqloss_init(self):
|
||||
_ = losses.GLVQLoss()
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
15
tox.ini
15
tox.ini
@@ -1,15 +0,0 @@
|
||||
# tox (https://tox.readthedocs.io/) is a tool for running tests
|
||||
# in multiple virtualenvs. This configuration file will run the
|
||||
# test suite on all supported python versions. To use it, "pip install tox"
|
||||
# and then run "tox" from this directory.
|
||||
|
||||
[tox]
|
||||
envlist = py36
|
||||
|
||||
[testenv]
|
||||
deps =
|
||||
numpy
|
||||
unittest-xml-reporting
|
||||
commands =
|
||||
python -m xmlrunner -o reports
|
||||
|
Reference in New Issue
Block a user