Compare commits
269 Commits
v0.1.6...feature/be
| SHA1 |
| --- |
| 9bb2e20dce |
| 6748951b63 |
| c547af728b |
| 482044ec87 |
| 45f01f39d4 |
| 9ab864fbdf |
| 365e0fb931 |
| ba50dfba50 |
| 16ca409f07 |
| c3cad19853 |
| ec294bdd37 |
| e0abb1f3de |
| 918e599c6a |
| ec61881ca8 |
| 5a89f24c10 |
| bcf9c6bdb1 |
| 736565b768 |
| 94730f492b |
| 46ec7b07d7 |
| 07dab5a5ca |
| ed83138e1f |
| 1be7d7ec09 |
| 60d2a1d2c9 |
| be7d7f43bd |
| fe729781fc |
| a7df7be1c8 |
| 696719600b |
| 48e7c029fa |
| 5de3a480c7 |
| 626f51ce80 |
| 6d7d93c8e8 |
| 93b1d0bd46 |
| b7992c01db |
| fcd944d3ff |
| 054720dd7b |
| 23d1a71b31 |
| e922aae432 |
| 3e50d0d817 |
| dc4f31d700 |
| 02954044d7 |
| 8f08ba66ea |
| e0b92e9ac2 |
| d16a0de202 |
| 76fea3f881 |
| c00513ae0d |
| bccef8bef0 |
| 29ee326b85 |
| 055568dc86 |
| 3a7328e290 |
| d6629c8792 |
| ef65bd3789 |
| d096eba2c9 |
| dd34c57e2e |
| 5911f4dd90 |
| dbfe315f4f |
| 9c90c902dc |
| 7d3f59e54b |
| 9da47b1dba |
| 41f0e77fc9 |
| fab786a07e |
| 40bd7ed380 |
| 4941c2b89d |
| ce14dec7e9 |
| b31c8cc707 |
| e21e6c7e02 |
| dd696ea1e0 |
| 15e7232747 |
| 197b728c63 |
| 98892afee0 |
| d5855dbe97 |
| 75a39f5b03 |
| 1a0e697b27 |
| 1a17193b35 |
| aaa3c51e0a |
| 62c5974a85 |
| 1d26226a2f |
| 4232d0ed2a |
| a9edf06507 |
| d3bb430104 |
| 6ffd27d12a |
| 859e2cae69 |
| d7ea89d47e |
| fa928afe2c |
| 7d4a041df2 |
| 04c51c00c6 |
| 62185b38cf |
| 7b93cd4ad5 |
| d7834e2cc0 |
| 0af8cf36f8 |
| f8ad1d83eb |
| 23a3683860 |
| 4be9fb81eb |
| 9d38123114 |
| 0f9f24e36a |
| 09e3ef1d0e |
| 7b9b767113 |
| f56ec44afe |
| 67a20124e8 |
| 72af03b991 |
| 71602bf38a |
| a1d9657b91 |
| 4dc11a3737 |
| 2649e3ac31 |
| 2b2e4a5f37 |
| 72404f7c4e |
| 612ee8dc6a |
| d42693a441 |
| e5ac50c9a7 |
| 561119ef1d |
| f1f0b313c9 |
| b9eb88a602 |
| 7eb496110f |
| 0a2da9ae50 |
| 4ab0a5a414 |
| 8956ee75ad |
| 29063dcec4 |
| a37095409b |
| 1b420c1f6b |
| 7ec5528ade |
| a44219ee47 |
| 24ebfdc667 |
| 1c658cdc1b |
| 1911d4b33e |
| 6197d7d5d6 |
| d2856383e2 |
| 4eafe88dc4 |
| 3afced8662 |
| 68034d56f6 |
| 97ec15b76a |
| 69e5ff3243 |
| c87ed5ba8b |
| fc11d78b38 |
| e62a8e6582 |
| ea33196a50 |
| 4ca846997a |
| 57f8bec270 |
| 022d791ea5 |
| 43fc7d1678 |
| c7b5c88776 |
| b031382072 |
| d558fa6a4a |
| 34ffeb95bc |
| 3aa33fd182 |
| f65a665157 |
| bed753a6e9 |
| b82bb54dbe |
| acac39cff6 |
| 19f601fac8 |
| 5d2a8226ce |
| 016fcb4060 |
| 20471bfb1c |
| 42d974e08c |
| b0df61d1c3 |
| 47db1965ee |
| 0bc385fe7b |
| 358f27257d |
| bda88149d4 |
| 7379c61966 |
| e209bf73d5 |
| 1b09b1d57b |
| 459f7c24be |
| 5918f1cc21 |
| 3b02d99ebe |
| 64250d0938 |
| 86688b26b0 |
| ef6bcc1079 |
| bdacc83185 |
| 8851d1bbc9 |
| a3f5d7d113 |
| b2009bb563 |
| 398431e7ea |
| 8f7deb75dd |
| 7743c50725 |
| fcf3a4979c |
| d46fe4a393 |
| 88cfd5762e |
| aa42b9e331 |
| 91b57b01b1 |
| 9eb6476078 |
| 98c198d463 |
| ef4d70eee0 |
| 7e241ff7d8 |
| 757f4e980d |
| 5ec2dd47cd |
| 930f84d3c7 |
| e8cd4d765c |
| 8403b01081 |
| aff6aedd60 |
| 1b6843dbbb |
| 21023a88d7 |
| 9c1a41997b |
| 1636c84778 |
| 27eccf44d4 |
| 8f4d66edf1 |
| 2a218c0ede |
| db064b5af1 |
| 0ac4ced85d |
| e9d2075fed |
| 7b7bc3693d |
| cd73f6c427 |
| a60337ff27 |
| e3392ee952 |
| dade502686 |
| b7edee02c3 |
| 41b2a2f496 |
| 66e3e51a52 |
| 0c1f7a4772 |
| 663eb12ad7 |
| 3fa1fb54f1 |
| cc49f26b77 |
| db965541fd |
| d091dea6a1 |
| d411e52be4 |
| 32d6f95db0 |
| 139109804f |
| 2cc11ae2e3 |
| 72e064338c |
| e7e6bf9173 |
| 2aa631f4e6 |
| c6992da123 |
| dcbd0c1e5c |
| b8bca71206 |
| 419eca46af |
| 5b12629bd9 |
| b60db3174a |
| 8ce18f83ce |
| 7b4f7d84e0 |
| a5e086ce0d |
| 0611f81aba |
| a9382dcd9b |
| 0933a88a1b |
| 88a34a06ef |
| 49f9a12b5f |
| de63eaf15a |
| 16dc3cf4eb |
| df061cc2ff |
| 969fb34cc3 |
| 0204f5eab6 |
| b7fc5df386 |
| faf1a88f99 |
| 5ffbd43a7c |
| fdf9443a2c |
| 7700bb7f8d |
| eefec19c9b |
| 246719b837 |
| a14e3aa611 |
| 00cdacf7ae |
| 4957e821f6 |
| 538256dcb7 |
| d812bb0620 |
| 81346785bd |
| 7a87636ad7 |
| 77b7b59bad |
| 6e7d80be88 |
| b7684ae512 |
| ebc42a4aa8 |
| c639836537 |
| d36d685115 |
| b341096757 |
| 0eac2ce326 |
| 8f9c29bd2b |
| ca39aa00d5 |
| 1498c4bde5 |
| 59b8ab6643 |
| 2a4f184163 |
| 265e74dd31 |
| daad018a78 |
| eab1ec72c2 |
| b38acd58a8 |
.bumpversion.cfg (modified)

```diff
@@ -1,11 +1,15 @@
 [bumpversion]
-current_version = 0.1.6
+current_version = 1.0.0a8
 commit = True
 tag = True
-parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
+parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
+serialize =
+    {major}.{minor}.{patch}-{release}
+    {major}.{minor}.{patch}
 message = build: bump version {current_version} → {new_version}
 
 [bumpversion:file:setup.py]
 
 [bumpversion:file:./prototorch/models/__init__.py]
 
+[bumpversion:file:./docs/source/conf.py]
```
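The widened `parse` pattern above is what lets bumpversion understand pre-release versions such as `1.0.0a8`, and the two `serialize` templates emit versions with and without the release tag. As a sanity check (an illustration only, not part of the repository), the same pattern can be exercised directly with Python's `re` module:

```python
import re

# The new `parse` pattern from the .bumpversion.cfg hunk above.
PARSE = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
                   r"((?P<release>[a-zA-Z0-9_.-]+))?")

for version in ("0.1.6", "1.0.0a8"):
    # fullmatch ensures the whole version string is consumed.
    print(version, PARSE.fullmatch(version).groupdict())
# 0.1.6   {'major': '0', 'minor': '1', 'patch': '6', 'release': None}
# 1.0.0a8 {'major': '1', 'minor': '0', 'patch': '0', 'release': 'a8'}
```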
.codacy.yml (deleted, 15 lines)

```yaml
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
  pylintpython3:
    exclude_paths:
      - config/engines.yml
  remark-lint:
    exclude_paths:
      - config/engines.yml
exclude_paths:
  - 'tests/**'
```
Deleted file (name not captured, 2 lines)

```yaml
comment:
  require_changes: yes
```
.github/ISSUE_TEMPLATE/bug_report.md (new file, 38 lines)

````markdown
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Steps to reproduce the behavior**
1. ...
2. Run script '...' or this snippet:
```python
import prototorch as pt

...
```
3. See errors

**Expected behavior**
A clear and concise description of what you expected to happen.

**Observed behavior**
A clear and concise description of what actually happened.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**System and version information**
- OS: [e.g. Ubuntu 20.10]
- ProtoTorch Version: [e.g. 0.4.0]
- Python Version: [e.g. 3.9.5]

**Additional context**
Add any other context about the problem here.
````
.github/ISSUE_TEMPLATE/feature_request.md (new file, 20 lines)

```markdown
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
```
.github/workflows/examples.yml (new file, 25 lines)

```yaml
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: examples

on:
  push:
    paths:
      - 'examples/**.py'
jobs:
  cpu:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.10
        uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[all]
      - name: Run examples
        run: |
          ./tests/test_examples.sh examples/
```
.github/workflows/pythonapp.yml (new file, 76 lines)

```yaml
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: tests

on:
  push:
  pull_request:
    branches: [ master ]

jobs:
  style:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.10
        uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[all]
      - uses: pre-commit/action@v3.0.0
  compatibility:
    needs: style
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.7", "3.8", "3.9", "3.10"]
        os: [ubuntu-latest, windows-latest]
        exclude:
          - os: windows-latest
            python-version: "3.7"
          - os: windows-latest
            python-version: "3.8"
          - os: windows-latest
            python-version: "3.9"
          - os: windows-latest
            python-version: "3.11"
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[all]
      - name: Test with pytest
        run: |
          pytest
  publish_pypi:
    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
    needs: compatibility
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.10
        uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[all]
          pip install wheel
      - name: Build package
        run: python setup.py sdist bdist_wheel
      - name: Publish a Python distribution to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
```
.gitignore (modified)

```diff
@@ -128,8 +128,19 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-# Datasets
-datasets/
-.vscode/
 
-# PyTorch-Lightning
+# Vim
+*~
+*.swp
+*.swo
+
+# Pytorch Models or Weights
+# If necessary make exceptions for single pretrained models
+*.pt
+
+# Artifacts created by ProtoTorch Models
+datasets/
+lightning_logs/
+examples/_*.py
+examples/_*.ipynb
+
```
.pre-commit-config.yaml (new file, 59 lines)

```yaml
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: trailing-whitespace
        exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-case-conflict

  - repo: https://github.com/myint/autoflake
    rev: v1.7.7
    hooks:
      - id: autoflake

  - repo: http://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.982
    hooks:
      - id: mypy
        files: prototorch
        additional_dependencies: [types-pkg_resources]

  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.32.0
    hooks:
      - id: yapf

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.9.0
    hooks:
      - id: python-use-type-annotations
      - id: python-no-log-warn
      - id: python-check-blanket-noqa

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.1.0
    hooks:
      - id: pyupgrade

  - repo: https://github.com/si-cim/gitlint
    rev: v0.15.2-unofficial
    hooks:
      - id: gitlint
        args: [--contrib=CT1, --ignore=B6, --msg-filename]

  - repo: https://github.com/dosisod/refurb
    rev: v1.4.0
    hooks:
      - id: refurb
```
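With this configuration in place, running `pre-commit run --all-files` applies every hook above to the whole tree, which is a quick way to verify the setup before the hooks fire on a real commit.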
.readthedocs.yml (new file, 27 lines)

```yaml
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/source/conf.py
  fail_on_warning: true

# Build documentation with MkDocs
# mkdocs:
#   configuration: mkdocs.yml

# Optionally build your docs in additional formats such as PDF and ePub
formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.9
  install:
    - method: pip
      path: .
      extra_requirements:
        - all
```
.remarkrc (new file, 7 lines)

```json
{
  "plugins": [
    "remark-preset-lint-recommended",
    ["remark-lint-list-item-indent", false],
    ["no-emphasis-as-header", true]
  ]
}
```
.travis.yml (deleted, 21 lines)

```yaml
dist: bionic
sudo: false
language: python
python: 3.8
cache:
  directories:
    - "./tests/artifacts"
install:
  - pip install .[all] --progress-bar off
script:
  - coverage run -m pytest
after_success:
  - bash <(curl -s https://codecov.io/bash)
deploy:
  provider: pypi
  username: __token__
  password:
    secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
  on:
    tags: true
  skip_existing: true
```
README.md (modified)

````diff
@@ -1,30 +1,58 @@
 # ProtoTorch Models
 
 [](https://travis-ci.org/si-cim/prototorch_models)
 [](https://github.com/si-cim/prototorch_models/releases)
 [](https://pypi.org/project/prototorch_models/)
 [](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
 
 Pre-packaged prototype-based machine learning models using ProtoTorch and
 PyTorch-Lightning.
 
 ## Installation
 
-To install this plugin, first install
-[ProtoTorch](https://github.com/si-cim/prototorch) with:
+To install this plugin, simply run the following command:
 
 ```sh
-git clone https://github.com/si-cim/prototorch.git && cd prototorch
-pip install -e .
+pip install prototorch_models
 ```
 
-and then install the plugin itself with:
+**Installing the models plugin should automatically install a suitable version
+of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
+be available for use in your Python environment as `prototorch.models`.
 
-```sh
-git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models
-pip install -e .
-```
+## Available models
 
-The plugin should then be available for use in your Python environment as
-`prototorch.models`.
+### LVQ Family
 
+- Learning Vector Quantization 1 (LVQ1)
+- Generalized Learning Vector Quantization (GLVQ)
+- Generalized Relevance Learning Vector Quantization (GRLVQ)
+- Generalized Matrix Learning Vector Quantization (GMLVQ)
+- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
+- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
+- Learning Vector Quantization Multi-Layer Network (LVQMLN)
+- Siamese GLVQ
+- Cross-Entropy Learning Vector Quantization (CELVQ)
+- Soft Learning Vector Quantization (SLVQ)
+- Robust Soft Learning Vector Quantization (RSLVQ)
+- Probabilistic Learning Vector Quantization (PLVQ)
+- Median-LVQ
+
+### Other
+
+- k-Nearest Neighbors (KNN)
+- Neural Gas (NG)
+- Growing Neural Gas (GNG)
+
+## Work in Progress
+
+- Classification-By-Components Network (CBC)
+- Learning Vector Quantization 2.1 (LVQ2.1)
+- Self-Organizing-Map (SOM)
+
+## Planned models
+
+- Generalized Tangent Learning Vector Quantization (GTLVQ)
+- Self-Incremental Learning Vector Quantization (SILVQ)
 
 ## Development setup
@@ -53,31 +81,26 @@ pip install -e .[all] # \[all\] if you are using zsh or MacOS
 ```
 
 To assist in the development process, you may also find it useful to install
-`yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
+`yapf`, `isort` and `autoflake`. You can install them easily with `pip`. **Also,
+please avoid installing Tensorflow in this environment. It is known to cause
+problems with PyTorch-Lightning.**
 
-## Available models
+## Contribution
 
-- Generalized Learning Vector Quantization (GLVQ)
-- Generalized Relevance Learning Vector Quantization (GRLVQ)
-- Generalized Matrix Learning Vector Quantization (GMLVQ)
-- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
-- Siamese GLVQ
-- Neural Gas (NG)
+This repository contains definition for [git hooks](https://githooks.com).
+[Pre-commit](https://pre-commit.com) is automatically installed as development
+dependency with prototorch or you can install it manually with `pip install
+pre-commit`.
 
-## Work in Progress
+Please install the hooks by running:
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+before creating the first commit.
 
-- Classification-By-Components Network (CBC)
-- Learning Vector Quantization Multi-Layer Network (LVQMLN)
-
-## Planned models
-
-- Local-Matrix GMLVQ
-- Generalized Tangent Learning Vector Quantization (GTLVQ)
-- Robust Soft Learning Vector Quantization (RSLVQ)
-- Probabilistic Learning Vector Quantization (PLVQ)
-- Self-Incremental Learning Vector Quantization (SILVQ)
-- K-Nearest Neighbors (KNN)
-- Learning Vector Quantization 1 (LVQ1)
+The commit will fail if the commit message does not follow the specification
+provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
 
 ## FAQ
 
````
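As an illustration of that rule: a message such as `fix: correct the prototype initializer shape` follows the Conventional Commits pattern (`type: description`), while a bare `Fixed initializer` would be rejected by the `gitlint` hook configured above.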
docs/Makefile (new file, 20 lines)

```make
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= python3 -m sphinx
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
```
docs/make.bat (new file, 35 lines)

```bat
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
```
docs/source/_static/img/horizontal-lockup.png (new binary file, 88 KiB)
docs/source/_static/img/logo.png (new binary file, 52 KiB)
docs/source/_static/img/model_tree.png (new binary file, 191 KiB)
docs/source/conf.py (new file, 209 lines)

```python
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

sys.path.insert(0, os.path.abspath("../../"))

# -- Project information -----------------------------------------------------

project = "ProtoTorch Models"
copyright = "2021, Jensun Ravichandran"
author = "Jensun Ravichandran"

# The full version, including alpha/beta/rc tags
#
release = "1.0.0-a8"

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
# ones.
extensions = [
    "recommonmark",
    "nbsphinx",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx_rtd_theme",
    "sphinxcontrib.katex",
    "sphinxcontrib.bibtex",
]

# https://nbsphinx.readthedocs.io/en/0.8.5/custom-css.html#For-All-Pages
nbsphinx_prolog = """
.. raw:: html

    <style>
        .nbinput .prompt,
        .nboutput .prompt {
            display: none;
        }
    </style>
"""

# katex_prerender = True
katex_prerender = False

napoleon_use_ivar = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]

# The master toctree document.
master_doc = "index"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
# "lovelace", "algol", "algol_nu", "arduino", "rainbow_dash", "abap",
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
# "stata-dark", "inkpot"]
pygments_style = "monokai"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True

# Disable docstring inheritance
autodoc_inherit_docstrings = False

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme"

html_logo = "_static/img/logo.png"

html_theme_options = {
    "logo_only": True,
    "display_version": True,
    "prev_next_buttons_location": "bottom",
    "style_external_links": False,
    "style_nav_header_background": "#ffffff",
    # Toc options
    "collapse_navigation": True,
    "sticky_navigation": True,
    "navigation_depth": 4,
    "includehidden": True,
    "titles_only": False,
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

html_css_files = [
    "https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
]

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "protoflowdoc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ("letterpaper" or "a4paper").
    #
    # "papersize": "letterpaper",
    # The font size ("10pt", "11pt" or "12pt").
    #
    # "pointsize": "10pt",
    # Additional stuff for the LaTeX preamble.
    #
    # "preamble": "",
    # Latex figure (float) alignment
    #
    # "figure_align": "htbp",
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (
        master_doc,
        "prototorch.tex",
        "ProtoTorch Documentation",
        "Jensun Ravichandran",
        "manual",
    ),
]

# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch Models",
              "ProtoTorch Models Plugin Documentation", [author], 1)]

# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "prototorch models",
        "ProtoTorch Models Plugin Documentation",
        author,
        "prototorch models",
        "Prototype-based machine learning Models in ProtoTorch.",
        "Miscellaneous",
    ),
]

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "torch": ('https://pytorch.org/docs/stable/', None),
    "pytorch_lightning":
    ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}

# -- Options for Epub output ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output

epub_cover = ()
version = release

# -- Options for Bibliography -------------------------------------------
bibtex_bibfiles = ['refs.bib']
bibtex_reference_style = 'author_year'
```
docs/source/custom.rst (new file, 7 lines)

```rst
.. Customize the Models

Abstract Models
========================================
.. automodule:: prototorch.models.abstract
    :members:
    :undoc-members:
```
docs/source/index.rst (new file, 49 lines)

```rst
.. ProtoTorch Models documentation master file

ProtoTorch Models Plugins
========================================
.. toctree::
   :hidden:
   :maxdepth: 3

   self
   tutorial.ipynb

.. toctree::
   :hidden:
   :maxdepth: 3
   :caption: Library

   library

.. toctree::
   :hidden:
   :maxdepth: 3
   :caption: Customize

   custom

.. toctree::
   :hidden:
   :maxdepth: 3
   :caption: Proto Y Architecture

   y-architecture

About
-----------------------------------------
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a plugin
for `Prototorch <https://github.com/si-cim/prototorch>`_. It implements common
prototype-based machine learning algorithms using `PyTorch-Lightning
<https://www.pytorchlightning.ai/>`_.

Library
-----------------------------------------
Prototorch Models delivers many application-ready models.
These models have been published in the past and have been adapted to the
Prototorch library.

Customizable
-----------------------------------------
Prototorch Models also contains the building blocks to build your own models
with PyTorch-Lightning and Prototorch.
```
docs/source/library.rst (new file, 117 lines)

```rst
.. Available Models

Models
========================================

.. image:: _static/img/model_tree.png
   :width: 600

Unsupervised Methods
-----------------------------------------
.. autoclass:: prototorch.models.knn.KNN
   :members:

.. autoclass:: prototorch.models.unsupervised.NeuralGas
   :members:

.. autoclass:: prototorch.models.unsupervised.GrowingNeuralGas
   :members:

Classical Learning Vector Quantization
-----------------------------------------
Original LVQ models introduced by :cite:t:`kohonen1989`.
These heuristic algorithms do not use gradient descent.

.. autoclass:: prototorch.models.lvq.LVQ1
   :members:
.. autoclass:: prototorch.models.lvq.LVQ21
   :members:

It is also possible to use the GLVQ structure as shown by :cite:t:`sato1996` in chapter 4.
This allows the use of gradient-descent methods.

.. autoclass:: prototorch.models.glvq.GLVQ1
   :members:
.. autoclass:: prototorch.models.glvq.GLVQ21
   :members:

Generalized Learning Vector Quantization
-----------------------------------------

:cite:t:`sato1996` presented an LVQ variant with a cost function, called GLVQ.
This allows the use of gradient-descent methods.

.. autoclass:: prototorch.models.glvq.GLVQ
   :members:

The cost function of GLVQ can be extended by a learnable dissimilarity.
These learnable dissimilarities assign relevances to each data dimension during the learning phase.
Examples are GRLVQ :cite:p:`hammer2002` and GMLVQ :cite:p:`schneider2009`.

.. autoclass:: prototorch.models.glvq.GRLVQ
   :members:

.. autoclass:: prototorch.models.glvq.GMLVQ
   :members:

The dissimilarity from GMLVQ can be interpreted as a projection into another data space.
Applying this projection only to the data results in LVQMLN.

.. autoclass:: prototorch.models.glvq.LVQMLN
   :members:

The projection idea from GMLVQ can be extended to an arbitrary transformation with learnable parameters.

.. autoclass:: prototorch.models.glvq.SiameseGLVQ
   :members:

Probabilistic Models
--------------------------------------------

Probabilistic variants assume that the prototypes generate a probability distribution over the classes.
For a test sample they return a distribution instead of a class assignment.

The following two algorithms were presented by :cite:t:`seo2003`.
Every prototype is the center of a Gaussian distribution of its class, generating a mixture model.

.. autoclass:: prototorch.models.probabilistic.SLVQ
   :members:

.. autoclass:: prototorch.models.probabilistic.RSLVQ
   :members:

:cite:t:`villmann2018` proposed two changes to RSLVQ: first, incorporate the winning rank into the prior probability calculation; second, use a divergence as the loss function.

.. autoclass:: prototorch.models.probabilistic.PLVQ
   :members:

Classification by Component
--------------------------------------------

The Classification-by-Components (CBC) model was introduced by :cite:t:`saralajew2019`.
In a CBC architecture there is no class assigned to the prototypes.
Instead, the dissimilarities are used in a reasoning process that favours or rejects a class by a learnable degree.
The output of a CBC network is a probability distribution over all classes.

.. autoclass:: prototorch.models.cbc.CBC
   :members:

.. autoclass:: prototorch.models.cbc.ImageCBC
   :members:

Visualization
========================================

Visualization is very specific to its application.
Prototorch Models provides visualization for two-dimensional data and image data.

The visualizations can be shown in a separate window and inside a TensorBoard.

.. automodule:: prototorch.models.vis
   :members:
   :undoc-members:

Bibliography
========================================
.. bibliography::
```
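The GLVQ cost function that `library.rst` keeps referring to minimizes the classifier value mu(x) = (d_plus - d_minus) / (d_plus + d_minus), where d_plus is the distance to the closest prototype of the correct class and d_minus the distance to the closest prototype of any other class. The following is a minimal, self-contained sketch of that cost in plain PyTorch. It is an illustration of the formula from Sato & Yamada (1996), not the actual `prototorch.models.glvq.GLVQ` implementation, and it assumes squared Euclidean distances and one label per prototype:

```python
import torch


def glvq_loss(x, y, protos, plabels):
    """Sketch of the GLVQ cost: mean of (d+ - d-) / (d+ + d-)."""
    d = torch.cdist(x, protos)**2  # squared Euclidean distances, (batch, num_protos)
    matches = y.unsqueeze(1) == plabels.unsqueeze(0)  # (batch, num_protos) label agreement
    inf = torch.tensor(float("inf"))
    d_plus = torch.where(matches, d, inf).min(dim=1).values    # nearest correct prototype
    d_minus = torch.where(~matches, d, inf).min(dim=1).values  # nearest incorrect prototype
    mu = (d_plus - d_minus) / (d_plus + d_minus)  # in (-1, 1); negative means correct
    return mu.mean()


# Toy usage: one prototype per class for 3 classes, a batch of two 2D points.
protos = torch.tensor([[0., 0.], [1., 1.], [2., 2.]], requires_grad=True)
plabels = torch.tensor([0, 1, 2])
x = torch.tensor([[0.1, 0.0], [1.9, 2.1]])
y = torch.tensor([0, 2])
loss = glvq_loss(x, y, protos, plabels)
loss.backward()  # gradients w.r.t. the prototypes
print(loss.item())
```

Because mu(x) is negative exactly when the correct-class prototype is closer, driving its mean down both pulls correct prototypes toward the data and pushes incorrect ones away, which is what makes the cost amenable to the gradient-descent training mentioned above.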
docs/source/refs.bib (new file, 72 lines)

```bibtex
@article{sato1996,
  title     = {Generalized learning vector quantization},
  author    = {Sato, Atsushi and Yamada, Keiji},
  journal   = {Advances in neural information processing systems},
  pages     = {423--429},
  year      = {1996},
  publisher = {MORGAN KAUFMANN PUBLISHERS},
  url       = {http://papers.nips.cc/paper/1113-generalized-learning-vector-quantization.pdf},
}

@book{kohonen1989,
  doi       = {10.1007/978-3-642-88163-3},
  year      = {1989},
  publisher = {Springer Berlin Heidelberg},
  author    = {Teuvo Kohonen},
  title     = {Self-Organization and Associative Memory},
}

@inproceedings{saralajew2019,
  author    = {Saralajew, Sascha and Holdijk, Lars and Rees, Maike and Asan, Ebubekir and Villmann, Thomas},
  booktitle = {Advances in Neural Information Processing Systems},
  title     = {Classification-by-Components: Probabilistic Modeling of Reasoning over a Set of Components},
  url       = {https://proceedings.neurips.cc/paper/2019/file/dca5672ff3444c7e997aa9a2c4eb2094-Paper.pdf},
  volume    = {32},
  year      = {2019},
}

@article{seo2003,
  author  = {Seo, Sambu and Obermayer, Klaus},
  title   = {Soft Learning Vector Quantization},
  journal = {Neural Computation},
  volume  = {15},
  number  = {7},
  pages   = {1589--1604},
  year    = {2003},
  month   = {07},
  doi     = {10.1162/089976603321891819},
}

@article{hammer2002,
  title   = {Generalized relevance learning vector quantization},
  journal = {Neural Networks},
  volume  = {15},
  number  = {8},
  pages   = {1059--1068},
  year    = {2002},
  doi     = {10.1016/S0893-6080(02)00079-5},
  author  = {Barbara Hammer and Thomas Villmann},
}

@article{schneider2009,
  author  = {Schneider, Petra and Biehl, Michael and Hammer, Barbara},
  title   = {Adaptive Relevance Matrices in Learning Vector Quantization},
  journal = {Neural Computation},
  volume  = {21},
  number  = {12},
  pages   = {3532--3561},
  year    = {2009},
  month   = {12},
  doi     = {10.1162/neco.2009.11-08-908},
}

@inproceedings{villmann2018,
  author    = {Villmann, Andrea and Kaden, Marika and Saralajew, Sascha and Villmann, Thomas},
  title     = {Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning},
  booktitle = {Artificial Intelligence and Soft Computing},
  year      = {2018},
  publisher = {Springer International Publishing},
}
```
docs/source/tutorial.ipynb (new file, 645 lines; notebook cells rendered below)

# A short tutorial for the `prototorch.models` plugin

## Introduction

This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.

[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practitioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially the same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you, so you don't have to reinvent the wheel every time you need to implement a new model.

With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines of code, it is now possible to build and train prototype models. It really cannot get much simpler than this.

## Basics

First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:

```python
import prototorch as pt
import pytorch_lightning as pl
import torch
```

### Building Models

Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer.

```python
model = pt.models.GLVQ(
    hparams=dict(distribution=[1, 1, 1]),
    prototypes_initializer=pt.initializers.ZerosCompInitializer(2),
)
print(model)
```

The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{0, 1, 2}`, each equipped with two prototypes. If, however, the dictionary contains the keys `"num_classes"` and `"per_class"`, they are parsed to use their values as one might expect.

The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` argument that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros.

### Data

The preferred way of working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There are a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets.

```python
train_ds = pt.datasets.Iris(dims=[0, 2])
type(train_ds)
train_ds.data.shape, train_ds.targets.shape
```

Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly.

```python
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)
type(train_loader)
x_batch, y_batch = next(iter(train_loader))
print(f"{x_batch=}, {y_batch=}")
```

This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches.

### Training

If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented.

```python
trainer = pl.Trainer(max_epochs=2, weights_summary=None)
trainer.fit(model, train_loader)
```

### From data to a trained model - a very minimal example

```python
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)

model = pt.models.GLVQ(
    dict(distribution=(3, 2), lr=0.1),
    prototypes_initializer=pt.initializers.SMCI(train_ds),
)

trainer = pl.Trainer(max_epochs=50, weights_summary=None)
trainer.fit(model, train_loader)
```

### Saving/Loading trained models

PyTorch-Lightning can automatically checkpoint the model during various stages of training, but it is also possible to manually save a checkpoint after training.

```python
ckpt_path = "./checkpoints/glvq_iris.ckpt"
trainer.save_checkpoint(ckpt_path)
loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)
```

### Visualizing decision boundaries in 2D

```python
pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)
```

### Saving/Loading trained weights

In most cases, the checkpointing workflow is sufficient. In some cases, however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has to be re-created using compatible initialization parameters before the weights can be loaded.

```python
ckpt_path = "./checkpoints/glvq_iris_weights.pth"
torch.save(model.state_dict(), ckpt_path)

model = pt.models.GLVQ(
    dict(distribution=(3, 2)),
    prototypes_initializer=pt.initializers.ZerosCompInitializer(2),
)
pt.models.VisGLVQ2D(data=train_ds, title="Before loading the weights").visualize(model)

torch.load(ckpt_path)
model.load_state_dict(torch.load(ckpt_path), strict=False)
pt.models.VisGLVQ2D(data=train_ds, title="After loading the weights").visualize(model)
```

## Advanced

### Warm-start a model with prototypes learned from another model

```python
trained_model = pt.models.GLVQ.load_from_checkpoint("./checkpoints/glvq_iris.ckpt", strict=False)
model = pt.models.SiameseGMLVQ(
    dict(input_dim=2,
         latent_dim=2,
         distribution=(3, 2),
         proto_lr=0.0001,
         bb_lr=0.0001),
    optimizer=torch.optim.Adam,
    prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),
    labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),
    omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])),  # permute axes
)
print(model)
pt.models.VisSiameseGLVQ2D(data=train_ds, title="GMLVQ - Warm-start state").visualize(model)
```

### Initializing prototypes with a subset of a dataset (along with transformations)

```python
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

from matplotlib import pyplot as plt

train_ds = MNIST(
    "~/datasets",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.ToTensor(),
    ]),
)

s = int(0.05 * len(train_ds))
init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])
init_ds

model = pt.models.ImageGLVQ(
    dict(distribution=(10, 1)),
    prototypes_initializer=pt.initializers.SMCI(init_ds),
)
plt.imshow(model.get_prototype_grid(num_columns=5))
```

We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:

```python
protos, plabels = pt.components.LabeledComponents(
    distribution=(10, 5),
    components_initializer=pt.initializers.SMCI(init_ds),
    labels_initializer=pt.initializers.LabelsInitializer(),
)()
plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap="jet")
```

## FAQs

### How do I retrieve the prototypes and their respective labels from the model?

For prototype models, the prototypes can be retrieved (as `torch.tensor`) as `model.prototypes`. You can convert it to a NumPy array by calling `.numpy()` on the tensor if required.

```python
>>> model.prototypes.numpy()
```

Similarly, the labels of the prototypes can be retrieved via `model.prototype_labels`.

```python
>>> model.prototype_labels
```

### How do I make inferences/predictions/recall with my trained model?

The models under [prototorch.models](https://github.com/si-cim/prototorch_models) provide a `.predict(x)` method for making predictions. This returns the predicted class labels. It is essential that the input to this method is a `torch.tensor` and not a NumPy array. Model instances are also callable. So, you could also just say `model(x)` as if `model` were just a function. However, this returns a (pseudo)-probability distribution over the classes.
"#### Example\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
">>> y_pred = model.predict(torch.Tensor(x_train)) # returns class labels\n",
|
||||
"```\n",
|
||||
"or, simply\n",
|
||||
"```python\n",
|
||||
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
71
docs/source/y-architecture.rst
Normal file
71
docs/source/y-architecture.rst
Normal file
@@ -0,0 +1,71 @@
|
||||
.. Documentation of the updated Architecture.
|
||||
|
||||
Proto Y Architecture
|
||||
========================================
|
||||
|
||||
Overview
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture is a framework for abstract prototype learning methods.
|
||||
|
||||
It divides the problem into multiple steps:
|
||||
|
||||
* **Components** : Recalling the position and metadata of the components/prototypes.
|
||||
* **Backbone** : Apply a mapping function to data and prototypes.
|
||||
* **Comparison** : Calculate a dissimilarity based on the latent positions.
|
||||
* **Competition** : Calculate competition values based on the comparison and the metadata.
|
||||
* **Loss** : Calculate the loss based on the competition values
|
||||
* **Inference** : Predict the output based on the competition values.
|
||||
|
||||
Depending on the phase (Training or Testing) Loss or Inference is used.
|
||||
|
||||
Inheritance Structure
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture has a single base class that defines all steps and hooks
|
||||
of the architecture.
|
||||
|
||||
.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture
|
||||
|
||||
**Steps**
|
||||
|
||||
Components
|
||||
|
||||
.. automethod:: init_components
|
||||
.. automethod:: components
|
||||
|
||||
Backbone
|
||||
|
||||
.. automethod:: init_backbone
|
||||
.. automethod:: backbone
|
||||
|
||||
Comparison
|
||||
|
||||
.. automethod:: init_comparison
|
||||
.. automethod:: comparison
|
||||
|
||||
Competition
|
||||
|
||||
.. automethod:: init_competition
|
||||
.. automethod:: competition
|
||||
|
||||
Loss
|
||||
|
||||
.. automethod:: init_loss
|
||||
.. automethod:: loss
|
||||
|
||||
Inference
|
||||
|
||||
.. automethod:: init_inference
|
||||
.. automethod:: inference
|
||||
|
||||
**Hooks**
|
||||
|
||||
Torchmetric
|
||||
|
||||
.. automethod:: register_torchmetric
|
||||
|
||||
Hyperparameters
|
||||
****************************************
|
||||
Every model implemented with the Proto Y Architecture has a set of hyperparameters,
|
||||
which is stored in the ``HyperParameters`` attribute of the architecture.
|
@@ -1,47 +0,0 @@
|
||||
"""CBC example using the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
input_dim=x_train.shape[1],
|
||||
nclasses=3,
|
||||
num_components=5,
|
||||
component_initializer=pt.components.SSI(train_ds, noise=0.01),
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.CBC(hparams)
|
||||
|
||||
# Callbacks
|
||||
dvis = pt.models.VisCBC2D(data=(x_train, y_train),
|
||||
title="CBC Iris Example")
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
callbacks=[
|
||||
dvis,
|
||||
],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,40 +0,0 @@
|
||||
"""GLVQ example using the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=3,
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=50,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,51 +0,0 @@
|
||||
"""GLVQ example using the spiral dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
def __init__(self, param):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||
if torch.isnan(self.param).any():
|
||||
raise ValueError("NaN encountered. Stopping.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=2,
|
||||
prototypes_per_class=20,
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
|
||||
transfer_function="sigmoid_beta",
|
||||
transfer_beta=10.0,
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||
snan = StopOnNaN(model.proto_layer.components)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
callbacks=[vis, snan],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,37 +1,144 @@
|
||||
"""GMLVQ example using all four dimensions of the Iris dataset."""
|
||||
import logging
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI, PCALinearTransformInitializer
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.architectures.base import Steps
|
||||
from prototorch.models.callbacks import (
|
||||
LogTorchmetricCallback,
|
||||
PlotLambdaMatrixToTensorboard,
|
||||
VisGMLVQ2D,
|
||||
)
|
||||
from prototorch.models.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
|
||||
def main():
|
||||
# ------------------------------------------------------------
|
||||
# DATA
|
||||
# ------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
full_dataset = Iris()
|
||||
full_count = len(full_dataset)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=3,
|
||||
prototypes_per_class=1,
|
||||
input_dim=x_train.shape[1],
|
||||
latent_dim=x_train.shape[1],
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.01,
|
||||
train_count = int(full_count * 0.5)
|
||||
val_count = int(full_count * 0.4)
|
||||
test_count = int(full_count * 0.1)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = random_split(
|
||||
full_dataset, (train_count, val_count, test_count))
|
||||
|
||||
# Dataloader
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=False,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(hparams)
|
||||
# ------------------------------------------------------------
|
||||
# HYPERPARAMETERS
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=100)
|
||||
# Select Initializer
|
||||
components_initializer = SMCI(full_dataset)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
# Define Hyperparameters
|
||||
hyperparameters = GMLVQ.HyperParameters(
|
||||
lr=dict(components_layer=0.1, _omega=0),
|
||||
input_dim=4,
|
||||
distribution=dict(
|
||||
num_classes=3,
|
||||
per_class=1,
|
||||
),
|
||||
component_initializer=components_initializer,
|
||||
omega_initializer=PCALinearTransformInitializer,
|
||||
omega_initializer_kwargs=dict(
|
||||
data=train_dataset.dataset[train_dataset.indices][0]))
|
||||
|
||||
# Display the Lambda matrix
|
||||
model.show_lambda()
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# TRAINING
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Controlling Callbacks
|
||||
recall = LogTorchmetricCallback(
|
||||
'training_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.TRAINING,
|
||||
)
|
||||
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'validation_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
accuracy = LogTorchmetricCallback(
|
||||
'validation_accuracy',
|
||||
torchmetrics.Accuracy,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
es = EarlyStopping(
|
||||
monitor=stopping_criterion.name,
|
||||
mode="max",
|
||||
patience=10,
|
||||
)
|
||||
|
||||
# Visualization Callback
|
||||
vis = VisGMLVQ2D(data=full_dataset)
|
||||
|
||||
# Define trainer
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
vis,
|
||||
recall,
|
||||
accuracy,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
],
|
||||
max_epochs=100,
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
trainer.test(model, test_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./y_arch.ckpt")
|
||||
|
||||
# Load saved model
|
||||
new_model = GMLVQ.load_from_checkpoint(
|
||||
checkpoint_path="./y_arch.ckpt",
|
||||
strict=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@@ -1,48 +0,0 @@
|
||||
"""Limited Rank Matrix LVQ example using the Tecator dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=32)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=2,
|
||||
prototypes_per_class=2,
|
||||
input_dim=100,
|
||||
latent_dim=2,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
lr=0.001,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Save the model
|
||||
torch.save(model, "liramlvq_tecator.pt")
|
||||
|
||||
# Load a saved model
|
||||
saved_model = torch.load("liramlvq_tecator.pt")
|
||||
|
||||
# Display the Lambda matrix
|
||||
saved_model.show_lambda()
|
@@ -1,42 +0,0 @@
|
||||
"""Classical LVQ using GLVQ example on the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=3,
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
#prototype_initializer=pt.components.Random(2),
|
||||
lr=0.005,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.LVQ1(hparams)
|
||||
#model = pt.models.LVQ21(hparams)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,40 +0,0 @@
|
||||
"""Neural Gas example using the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepare and pre-process the dataset
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(num_prototypes=30, lr=0.03)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.NeuralGas(hparams)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisNG2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,64 +0,0 @@
|
||||
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
"""Two fully connected layers with ReLU activation."""
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.latent_size = latent_size
|
||||
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
||||
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.dense1(x))
|
||||
out = self.relu(self.dense2(x))
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
from sklearn.datasets import load_iris
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
nclasses=3,
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer=pt.components.SMI((x_train, y_train)),
|
||||
proto_lr=0.001,
|
||||
bb_lr=0.001,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGLVQ(
|
||||
hparams,
|
||||
backbone_module=Backbone,
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,8 +1,25 @@
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from .architectures.base import BaseYArchitecture
|
||||
from .architectures.comparison import (
|
||||
OmegaComparisonMixin,
|
||||
SimpleComparisonMixin,
|
||||
)
|
||||
from .architectures.competition import WTACompetitionMixin
|
||||
from .architectures.components import SupervisedArchitecture
|
||||
from .architectures.loss import GLVQLossMixin
|
||||
from .architectures.optimization import (
|
||||
MultipleLearningRateMixin,
|
||||
SingleLearningRateMixin,
|
||||
)
|
||||
|
||||
from .cbc import CBC
|
||||
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
|
||||
from .neural_gas import NeuralGas
|
||||
from .vis import *
|
||||
__all__ = [
|
||||
'BaseYArchitecture',
|
||||
"OmegaComparisonMixin",
|
||||
"SimpleComparisonMixin",
|
||||
"SingleLearningRateMixin",
|
||||
"MultipleLearningRateMixin",
|
||||
"SupervisedArchitecture",
|
||||
"WTACompetitionMixin",
|
||||
"GLVQLossMixin",
|
||||
]
|
||||
|
||||
__version__ = "0.1.6"
|
||||
__version__ = "1.0.0-a8"
|
||||
|
@@ -1,23 +0,0 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
|
||||
class AbstractLightningModel(pl.LightningModule):
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = ExponentialLR(optimizer,
|
||||
gamma=0.99,
|
||||
last_epoch=-1,
|
||||
verbose=False)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return [optimizer], [sch]
|
||||
|
||||
|
||||
class AbstractPrototypeModel(AbstractLightningModel):
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.proto_layer.components.detach().cpu()
|
290
prototorch/models/architectures/base.py
Normal file
290
prototorch/models/architectures/base.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Proto Y Architecture
|
||||
|
||||
Network architecture for Component based Learning.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
|
||||
class Steps(enumerate):
|
||||
TRAINING = "training"
|
||||
VALIDATION = "validation"
|
||||
TEST = "test"
|
||||
PREDICT = "predict"
|
||||
|
||||
|
||||
class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
@dataclass
|
||||
class HyperParameters:
|
||||
"""
|
||||
Add all hyperparameters in the inherited class.
|
||||
"""
|
||||
...
|
||||
|
||||
# Fields
|
||||
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
registered_metric_callbacks: dict[str, dict[type[Metric],
|
||||
set[Callable]]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
|
||||
# Type Hints for Necessary Fields
|
||||
components_layer: torch.nn.Module
|
||||
|
||||
def __init__(self, hparams) -> None:
|
||||
if isinstance(hparams, dict):
|
||||
self.save_hyperparameters(hparams)
|
||||
# TODO: => Move into Component Child
|
||||
del hparams["initialized_proto_shape"]
|
||||
hparams = self.HyperParameters(**hparams)
|
||||
else:
|
||||
hparams_dict = asdict(hparams)
|
||||
hparams_dict["component_initializer"] = None
|
||||
self.save_hyperparameters(hparams_dict, )
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Common Steps
|
||||
self.init_components(hparams)
|
||||
self.init_backbone(hparams)
|
||||
self.init_comparison(hparams)
|
||||
self.init_competition(hparams)
|
||||
|
||||
# Train Steps
|
||||
self.init_loss(hparams)
|
||||
|
||||
# Inference Steps
|
||||
self.init_inference(hparams)
|
||||
|
||||
# external API
|
||||
def get_competition(self, batch, components):
|
||||
'''
|
||||
Returns the output of the competition layer.
|
||||
'''
|
||||
latent_batch, latent_components = self.backbone(batch, components)
|
||||
# TODO: => Latent Hook
|
||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||
# TODO: => Comparison Hook
|
||||
return comparison_tensor
|
||||
|
||||
def forward(self, batch):
|
||||
'''
|
||||
Returns the prediction.
|
||||
'''
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = (batch, None)
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competition(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.inference(comparison_tensor, components)
|
||||
|
||||
def predict(self, batch):
|
||||
"""
|
||||
Alias for forward
|
||||
"""
|
||||
return self.forward(batch)
|
||||
|
||||
def forward_comparison(self, batch):
|
||||
'''
|
||||
Returns the Output of the comparison layer.
|
||||
'''
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = (batch, None)
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
return self.get_competition(batch, components)
|
||||
|
||||
def loss_forward(self, batch):
|
||||
'''
|
||||
Returns the output of the loss layer.
|
||||
'''
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competition(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.loss(comparison_tensor, batch, components)
|
||||
|
||||
# Empty Initialization
|
||||
def init_components(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the components step.
|
||||
"""
|
||||
|
||||
def init_backbone(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the backbone step.
|
||||
"""
|
||||
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the comparison step.
|
||||
"""
|
||||
|
||||
def init_competition(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the competition step.
|
||||
"""
|
||||
|
||||
def init_loss(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the loss step.
|
||||
"""
|
||||
|
||||
def init_inference(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the inference step.
|
||||
"""
|
||||
|
||||
# Empty Steps
|
||||
def components(self):
|
||||
"""
|
||||
This step has no input.
|
||||
|
||||
It returns the components.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The components step has no reasonable default.")
|
||||
|
||||
def backbone(self, batch, components):
|
||||
"""
|
||||
The backbone step receives the data batch and the components.
|
||||
It can transform both by an arbitrary function.
|
||||
|
||||
It returns the transformed batch and components,
|
||||
each of the same length as the original input.
|
||||
"""
|
||||
return batch, components
|
||||
|
||||
def comparison(self, batch, components):
|
||||
"""
|
||||
Takes a batch of size N and the component set of size M.
|
||||
|
||||
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The comparison step has no reasonable default.")
|
||||
|
||||
def competition(self, comparison_measures, components):
|
||||
"""
|
||||
Takes the tensor of comparison measures.
|
||||
|
||||
Assigns a competition vector to each class.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The competition step has no reasonable default.")
|
||||
|
||||
def loss(self, comparison_measures, batch, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Calculates a single loss value
|
||||
"""
|
||||
raise NotImplementedError("The loss step has no reasonable default.")
|
||||
|
||||
def inference(self, comparison_measures, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Returns the inferred vector.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The inference step has no reasonable default.")
|
||||
|
||||
# Y Architecture Hooks
|
||||
|
||||
# internal API, called by models and callbacks
|
||||
def register_torchmetric(
|
||||
self,
|
||||
name: Callable,
|
||||
metric: type[Metric],
|
||||
step: str = Steps.TRAINING,
|
||||
**metric_kwargs,
|
||||
):
|
||||
'''
|
||||
Register a callback for evaluating a torchmetric.
|
||||
'''
|
||||
if step == Steps.PREDICT:
|
||||
raise ValueError("Prediction metrics are not supported.")
|
||||
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[step][metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_callbacks[step][metric] = {name}
|
||||
else:
|
||||
self.registered_metric_callbacks[step][metric].add(name)
|
||||
|
||||
def update_metrics_step(self, batch, step):
|
||||
# Prediction Metrics
|
||||
preds = self(batch)
|
||||
|
||||
_, y = batch
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
instance(y, preds.reshape(y.shape))
|
||||
|
||||
def update_metrics_epoch(self, step):
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
value = instance.compute()
|
||||
|
||||
for callback in self.registered_metric_callbacks[step][metric]:
|
||||
callback(value, self)
|
||||
|
||||
instance.reset()
|
||||
|
||||
# Lightning steps
|
||||
# -------------------------------------------------------------------------
|
||||
# >>>> Training
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
self.update_metrics_step(batch, Steps.TRAINING)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def training_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.TRAINING)
|
||||
|
||||
# >>>> Validation
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.VALIDATION)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.VALIDATION)
|
||||
|
||||
# >>>> Test
|
||||
def test_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.TEST)
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.TEST)
|
||||
|
||||
# >>>> Prediction
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
return self.predict(batch)
|
||||
|
||||
# Check points
|
||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||
# Compatible with Lightning
|
||||
checkpoint["hyper_parameters"] = {
|
||||
'hparams': checkpoint["hyper_parameters"]
|
||||
}
|
||||
return super().on_save_checkpoint(checkpoint)
|
148
prototorch/models/architectures/comparison.py
Normal file
148
prototorch/models/architectures/comparison.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.core.initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeLinearTransformInitializer,
|
||||
)
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SimpleComparisonMixin(BaseYArchitecture):
|
||||
"""
|
||||
Simple Comparison
|
||||
|
||||
A comparison layer that only uses the positions of the components
|
||||
and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = euclidean_distance
|
||||
comparison_args: dict = field(default_factory=dict)
|
||||
|
||||
comparison_parameters: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters):
|
||||
self.comparison_layer = LambdaLayer(
|
||||
fn=hparams.comparison_fn,
|
||||
**hparams.comparison_args,
|
||||
)
|
||||
|
||||
self.comparison_kwargs: dict[str, Tensor] = {}
|
||||
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
batch_tensor, _ = batch
|
||||
|
||||
comp_tensor = comp_tensor.unsqueeze(1)
|
||||
|
||||
distances = self.comparison_layer(
|
||||
batch_tensor,
|
||||
comp_tensor,
|
||||
**self.comparison_kwargs,
|
||||
)
|
||||
|
||||
return distances
|
||||
|
||||
|
||||
class OmegaComparisonMixin(SimpleComparisonMixin):
|
||||
"""
|
||||
Omega Comparison
|
||||
|
||||
A comparison layer that uses the positions of the components
|
||||
and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
_omega: torch.Tensor
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(SimpleComparisonMixin.HyperParameters):
|
||||
"""
|
||||
input_dim: Necessary Field: The dimensionality of the input.
|
||||
latent_dim:
|
||||
The dimensionality of the latent space. Default: 2.
|
||||
omega_initializer:
|
||||
The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
||||
"""
|
||||
input_dim: int | None = None
|
||||
latent_dim: int = 2
|
||||
omega_initializer: type[
|
||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||
omega_initializer_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
super().init_comparison(hparams)
|
||||
|
||||
# Initialize the omega matrix
|
||||
if hparams.input_dim is None:
|
||||
raise ValueError("input_dim must be specified.")
|
||||
else:
|
||||
omega = hparams.omega_initializer(
|
||||
**hparams.omega_initializer_kwargs).generate(
|
||||
hparams.input_dim,
|
||||
hparams.latent_dim,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
self.comparison_kwargs = dict(omega=self._omega)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
'''
|
||||
Omega Matrix. Mapping applied to data and prototypes.
|
||||
'''
|
||||
return self._omega.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
'''
|
||||
Lambda Matrix.
|
||||
'''
|
||||
omega = self._omega.detach()
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
'''
|
||||
Relevance Profile. Main Diagonal of the Lambda Matrix.
|
||||
'''
|
||||
return self.lambda_matrix.diag().abs()
|
||||
|
||||
@property
|
||||
def classification_influence_profile(self):
|
||||
'''
|
||||
Classification Influence Profile. Influence of each dimension.
|
||||
'''
|
||||
lam = self.lambda_matrix
|
||||
return lam.abs().sum(0)
|
||||
|
||||
@property
|
||||
def parameter_omega(self):
|
||||
return self._omega
|
||||
|
||||
@parameter_omega.setter
|
||||
def parameter_omega(self, new_omega):
|
||||
with torch.no_grad():
|
||||
self._omega.data.copy_(new_omega)
|
29
prototorch/models/architectures/competition.py
Normal file
29
prototorch/models/architectures/competition.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.core.competitions import WTAC
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class WTACompetitionMixin(BaseYArchitecture):
|
||||
"""
|
||||
Winner Take All Competition
|
||||
|
||||
A competition layer that uses the winner-take-all strategy.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
No hyperparameters.
|
||||
"""
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_inference(self, hparams: HyperParameters):
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
def inference(self, comparison_measures, components):
|
||||
comp_labels = components[1]
|
||||
return self.competition_layer(comparison_measures, comp_labels)
|
64
prototorch/models/architectures/components.py
Normal file
64
prototorch/models/architectures/components.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.core.components import LabeledComponents
|
||||
from prototorch.core.initializers import (
|
||||
AbstractComponentsInitializer,
|
||||
LabelsInitializer,
|
||||
ZerosCompInitializer,
|
||||
)
|
||||
from prototorch.models import BaseYArchitecture
|
||||
|
||||
|
||||
class SupervisedArchitecture(BaseYArchitecture):
|
||||
"""
|
||||
Supervised Architecture
|
||||
|
||||
An architecture that uses labeled Components as component Layer.
|
||||
"""
|
||||
components_layer: LabeledComponents
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters:
|
||||
"""
|
||||
distribution: A valid prototype distribution. No default possible.
|
||||
components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
|
||||
"""
|
||||
distribution: "dict[str, int]"
|
||||
component_initializer: AbstractComponentsInitializer
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_components(self, hparams: HyperParameters):
|
||||
if hparams.component_initializer is not None:
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=hparams.component_initializer,
|
||||
labels_initializer=LabelsInitializer(),
|
||||
)
|
||||
proto_shape = self.components_layer.components.shape[1:]
|
||||
self.hparams["initialized_proto_shape"] = proto_shape
|
||||
else:
|
||||
# when restoring a checkpointed model
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=ZerosCompInitializer(
|
||||
self.hparams["initialized_proto_shape"]),
|
||||
)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@property
|
||||
def prototypes(self):
|
||||
"""
|
||||
Returns the position of the prototypes.
|
||||
"""
|
||||
return self.components_layer.components.detach().cpu()
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
"""
|
||||
Returns the labels of the prototypes.
|
||||
"""
|
||||
return self.components_layer.labels.detach().cpu()
|
42
prototorch/models/architectures/loss.py
Normal file
42
prototorch/models/architectures/loss.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from prototorch.core.losses import GLVQLoss
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class GLVQLossMixin(BaseYArchitecture):
|
||||
"""
|
||||
GLVQ Loss
|
||||
|
||||
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
margin: The margin of the GLVQ loss. Default: 0.0.
|
||||
transfer_fn: Transfer function to use. Default: sigmoid_beta.
|
||||
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
|
||||
"""
|
||||
margin: float = 0.0
|
||||
|
||||
transfer_fn: str = "sigmoid_beta"
|
||||
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_loss(self, hparams: HyperParameters):
|
||||
self.loss_layer = GLVQLoss(
|
||||
margin=hparams.margin,
|
||||
transfer_fn=hparams.transfer_fn,
|
||||
**hparams.transfer_args,
|
||||
)
|
||||
|
||||
def loss(self, comparison_measures, batch, components):
|
||||
target = batch[1]
|
||||
comp_labels = components[1]
|
||||
loss = self.loss_layer(comparison_measures, target, comp_labels)
|
||||
self.log('loss', loss)
|
||||
return loss
|
73
prototorch/models/architectures/optimization.py
Normal file
73
prototorch/models/architectures/optimization.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from prototorch.models import BaseYArchitecture
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SingleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Single Learning Rate
|
||||
|
||||
All parameters are updated with a single learning rate.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: float = 0.1
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
return self.hparams.optimizer(self.parameters(),
|
||||
lr=self.hparams.lr) # type: ignore
|
||||
|
||||
|
||||
class MultipleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Multiple Learning Rates
|
||||
|
||||
Define Different Learning Rates for different parameters.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: dict = field(default_factory=dict)
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
optimizers = []
|
||||
for name, lr in self.hparams.lr.items():
|
||||
if not hasattr(self, name):
|
||||
raise ValueError(f"{name} is not a parameter of {self}")
|
||||
else:
|
||||
model_part = getattr(self, name)
|
||||
if isinstance(model_part, Parameter):
|
||||
optimizers.append(
|
||||
self.hparams.optimizer(
|
||||
[model_part],
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
elif hasattr(model_part, "parameters"):
|
||||
optimizers.append(
|
||||
self.hparams.optimizer(
|
||||
model_part.parameters(),
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
return optimizers
|
307
prototorch/models/callbacks.py
Normal file
307
prototorch/models/callbacks.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import logging
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.models.architectures.base import BaseYArchitecture, Steps
|
||||
from prototorch.models.architectures.comparison import OmegaComparisonMixin
|
||||
from prototorch.models.library.gmlvq import GMLVQ
|
||||
from prototorch.models.vis import Vis2DAbstract
|
||||
from prototorch.utils.utils import mesh2d
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
DIVERGING_COLOR_MAPS = [
|
||||
'PiYG',
|
||||
'PRGn',
|
||||
'BrBG',
|
||||
'PuOr',
|
||||
'RdGy',
|
||||
'RdBu',
|
||||
'RdYlBu',
|
||||
'RdYlGn',
|
||||
'Spectral',
|
||||
'coolwarm',
|
||||
'bwr',
|
||||
'seismic',
|
||||
]
|
||||
|
||||
|
||||
class LogTorchmetricCallback(pl.Callback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
metric: Type[torchmetrics.Metric],
|
||||
step: str = Steps.TRAINING,
|
||||
on_epoch=True,
|
||||
**metric_kwargs,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.metric = metric
|
||||
self.metric_kwargs = metric_kwargs
|
||||
self.step = step
|
||||
self.on_epoch = on_epoch
|
||||
|
||||
def setup(
|
||||
self,
|
||||
trainer: pl.Trainer,
|
||||
pl_module: BaseYArchitecture,
|
||||
stage: Optional[str] = None,
|
||||
) -> None:
|
||||
pl_module.register_torchmetric(
|
||||
self,
|
||||
self.metric,
|
||||
step=self.step,
|
||||
**self.metric_kwargs,
|
||||
)
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
pl_module.log(
|
||||
self.name,
|
||||
value,
|
||||
on_epoch=self.on_epoch,
|
||||
on_step=(not self.on_epoch),
|
||||
)
|
||||
|
||||
|
||||
class LogConfusionMatrix(LogTorchmetricCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
name="confusion",
|
||||
on='prediction',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
torchmetrics.ConfusionMatrix,
|
||||
on=on,
|
||||
num_classes=num_classes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(value.detach().cpu().numpy())
|
||||
|
||||
# Show all ticks and label them with the respective list entries
|
||||
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
|
||||
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
|
||||
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(
|
||||
ax.get_xticklabels(),
|
||||
rotation=45,
|
||||
ha="right",
|
||||
rotation_mode="anchor",
|
||||
)
|
||||
|
||||
# Loop over data dimensions and create text annotations.
|
||||
for i in range(len(value)):
|
||||
for j in range(len(value)):
|
||||
text = ax.text(
|
||||
j,
|
||||
i,
|
||||
value[i, j].item(),
|
||||
ha="center",
|
||||
va="center",
|
||||
color="w",
|
||||
)
|
||||
|
||||
ax.set_title(self.name)
|
||||
fig.tight_layout()
|
||||
|
||||
pl_module.logger.experiment.add_figure(
|
||||
tag=self.name,
|
||||
figure=fig,
|
||||
close=True,
|
||||
global_step=pl_module.global_step,
|
||||
)
|
||||
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax()
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
if x_train is not None:
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(
|
||||
np.vstack([x_train, protos]),
|
||||
self.border,
|
||||
self.resolution,
|
||||
)
|
||||
else:
|
||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||
_components = pl_module.components_layer.components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
|
||||
class VisGMLVQ2D(Vis2DAbstract):
|
||||
|
||||
def __init__(self, *args, ev_proj=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ev_proj = ev_proj
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
device = pl_module.device
|
||||
omega = pl_module._omega.detach()
|
||||
lam = omega @ omega.T
|
||||
u, _, _ = torch.pca_lowrank(lam, q=2)
|
||||
with torch.no_grad():
|
||||
x_train = torch.Tensor(x_train).to(device)
|
||||
x_train = x_train @ u
|
||||
x_train = x_train.cpu().detach()
|
||||
if self.show_protos:
|
||||
with torch.no_grad():
|
||||
protos = torch.Tensor(protos).to(device)
|
||||
protos = protos @ u
|
||||
protos = protos.cpu().detach()
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
|
||||
|
||||
class PlotLambdaMatrixToTensorboard(pl.Callback):
|
||||
|
||||
def __init__(self, cmap='seismic') -> None:
|
||||
super().__init__()
|
||||
self.cmap = cmap
|
||||
|
||||
if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
|
||||
warnings.warn(
|
||||
f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
|
||||
)
|
||||
|
||||
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||
self.plot_lambda(trainer, pl_module)
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||
self.plot_lambda(trainer, pl_module)
|
||||
|
||||
def plot_lambda(self, trainer, pl_module: GMLVQ):
|
||||
|
||||
self.fig, self.ax = plt.subplots(1, 1)
|
||||
|
||||
# plot lambda matrix
|
||||
l_matrix = pl_module.lambda_matrix
|
||||
|
||||
# normalize lambda matrix
|
||||
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
|
||||
|
||||
# plot lambda matrix
|
||||
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
|
||||
|
||||
self.fig.colorbar(self.ax.images[-1])
|
||||
|
||||
# add title
|
||||
self.ax.set_title('Lambda Matrix')
|
||||
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
"lambda_matrix",
|
||||
self.fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class Profiles(Enum):
|
||||
'''
|
||||
Available Profiles
|
||||
'''
|
||||
RELEVANCE = 'relevance'
|
||||
INFLUENCE = 'influence'
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class PlotMatrixProfiles(pl.Callback):
|
||||
|
||||
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
|
||||
super().__init__()
|
||||
self.cmap = cmap
|
||||
self.profile = profile
|
||||
|
||||
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot initial profile.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot after every epoch.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def _plot_profile(self, trainer, pl_module: GMLVQ):
|
||||
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
|
||||
# plot lambda matrix
|
||||
l_matrix = torch.abs(pl_module.lambda_matrix)
|
||||
|
||||
if self.profile == Profiles.RELEVANCE:
|
||||
profile_value = l_matrix.diag()
|
||||
elif self.profile == Profiles.INFLUENCE:
|
||||
profile_value = l_matrix.sum(0)
|
||||
|
||||
# plot lambda matrix
|
||||
ax.plot(profile_value.detach().numpy())
|
||||
|
||||
# add title
|
||||
ax.set_title(f'{self.profile} profile')
|
||||
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
f"{self.profile}_matrix",
|
||||
fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
else:
|
||||
class_name = self.__class__.__name__
|
||||
logger_name = trainer.logger.__class__.__name__
|
||||
warnings.warn(
|
||||
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class OmegaTraceNormalization(pl.Callback):
|
||||
'''
|
||||
Trace normalization of the Omega Matrix.
|
||||
'''
|
||||
__epsilon = torch.finfo(torch.float32).eps
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer",
|
||||
pl_module: OmegaComparisonMixin) -> None:
|
||||
|
||||
omega = pl_module.parameter_omega
|
||||
denominator = torch.sqrt(torch.trace(omega.T @ omega))
|
||||
logging.debug(
|
||||
"Apply Omega Trace Normalization: demoninator=%f",
|
||||
denominator.item(),
|
||||
)
|
||||
pl_module.parameter_omega = omega / (denominator + self.__epsilon)
|
@@ -1,165 +0,0 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components.components import Components
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.similarities import cosine_similarity
|
||||
|
||||
|
||||
def rescaled_cosine_similarity(x, y):
|
||||
"""Cosine Similarity rescaled to [0, 1]."""
|
||||
similarities = cosine_similarity(x, y)
|
||||
return (similarities + 1.0) / 2.0
|
||||
|
||||
|
||||
def shift_activation(x):
|
||||
return (x + 1.0) / 2.0
|
||||
|
||||
|
||||
def euclidean_similarity(x, y):
|
||||
d = euclidean_distance(x, y)
|
||||
return torch.exp(-d * 3)
|
||||
|
||||
|
||||
class CosineSimilarity(torch.nn.Module):
|
||||
def __init__(self, activation=shift_activation):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x, y):
|
||||
epsilon = torch.finfo(x.dtype).eps
|
||||
normed_x = (x / x.pow(2).sum(dim=tuple(range(
|
||||
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
||||
start_dim=1)
|
||||
normed_y = (y / y.pow(2).sum(dim=tuple(range(
|
||||
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
||||
start_dim=1)
|
||||
# normed_x = (x / torch.linalg.norm(x, dim=1))
|
||||
diss = torch.inner(normed_x, normed_y)
|
||||
return self.activation(diss)
|
||||
|
||||
|
||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||
def __init__(self,
|
||||
margin=0.3,
|
||||
size_average=None,
|
||||
reduce=None,
|
||||
reduction="mean"):
|
||||
super().__init__(size_average, reduce, reduction)
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, input_, target):
|
||||
dp = torch.sum(target * input_, dim=-1)
|
||||
dm = torch.max(input_ - target, dim=-1).values
|
||||
return torch.nn.functional.relu(dm - dp + self.margin)
|
||||
|
||||
|
||||
class ReasoningLayer(torch.nn.Module):
    def __init__(self, n_components, n_classes, n_replicas=1):
        super().__init__()
        self.n_replicas = n_replicas
        self.n_classes = n_classes
        probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
        probabilities_init.uniform_(0.4, 0.6)
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)

    @property
    def reasonings(self):
        pk = self.reasoning_probabilities[0]
        nk = (1 - pk) * self.reasoning_probabilities[1]
        ik = 1 - pk - nk
        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
        return img.unsqueeze(1)

    def forward(self, detections):
        pk = self.reasoning_probabilities[0].clamp(0, 1)
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
        epsilon = torch.finfo(pk.dtype).eps
        numerator = (detections @ (pk - nk)) + nk.sum(1)
        # epsilon guards against an all-zero reasoning vector (it was
        # previously computed but left unused)
        probs = numerator / ((pk + nk).sum(1) + epsilon)
        probs = probs.squeeze(0)
        return probs


class CBC(pl.LightningModule):
    """Classification-By-Components."""
    def __init__(self,
                 hparams,
                 margin=0.1,
                 backbone_class=torch.nn.Identity,
                 similarity=euclidean_similarity,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.margin = margin
        self.component_layer = Components(self.hparams.num_components,
                                          self.hparams.component_initializer)
        # self.similarity = CosineSimilarity()
        self.similarity = similarity
        self.backbone = backbone_class()
        self.backbone_dependent = backbone_class().requires_grad_(False)
        n_components = self.components.shape[0]
        self.reasoning_layer = ReasoningLayer(n_components=n_components,
                                              n_classes=self.hparams.nclasses)
        self.train_acc = torchmetrics.Accuracy()

    @property
    def components(self):
        return self.component_layer.components.detach().cpu()

    @property
    def reasonings(self):
        return self.reasoning_layer.reasonings.cpu()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def sync_backbones(self):
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def forward(self, x):
        self.sync_backbones()
        protos = self.component_layer()

        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)

        detections = self.similarity(latent_x, latent_protos)
        probs = self.reasoning_layer(detections)
        return probs

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        y_pred = self(x)
        nclasses = self.reasoning_layer.n_classes
        y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
        loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
        self.log("train_loss", loss)
        self.train_acc(y_pred, y_true)
        self.log(
            "acc",
            self.train_acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def predict(self, x):
        with torch.no_grad():
            y_pred = self(x)
            y_pred = torch.argmax(y_pred, dim=1)
        return y_pred.numpy()


class ImageCBC(CBC):
    """CBC model that constrains the components to the range [0, 1] by
    clamping after updates.
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
        # `components`, not `prototypes`: the Components layer stores its
        # tensor under `components` (cf. ImageGLVQ below)
        self.component_layer.components.data.clamp_(0.0, 1.0)
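A minimal training sketch for the CBC model above. The hyperparameter keys mirror the `hparams` accesses in `__init__` (num_components, component_initializer, nclasses, lr); the initializer shown is the one used elsewhere in this diff, and the data loader is a placeholder:

    import pytorch_lightning as pl
    from prototorch.components import initializers as cinit

    hparams = dict(
        num_components=4,
        component_initializer=cinit.ZerosInitializer(2),
        nclasses=2,
        lr=0.01,
    )
    model = CBC(hparams)
    trainer = pl.Trainer(max_epochs=50)
    trainer.fit(model, train_loader)  # train_loader: DataLoader of (x, y)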
@@ -1,312 +0,0 @@
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                            squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from torch.optim.lr_scheduler import ExponentialLR

from .abstract import AbstractPrototypeModel


class GLVQ(AbstractPrototypeModel):
    """Generalized Learning Vector Quantization."""
    def __init__(self, hparams, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        # Default values
        self.hparams.setdefault("distance", euclidean_distance)
        self.hparams.setdefault("optimizer", torch.optim.Adam)
        self.hparams.setdefault("transfer_function", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)

        self.proto_layer = LabeledComponents(
            labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
            initializer=self.hparams.prototype_initializer)

        self.transfer_function = get_activation(self.hparams.transfer_function)
        self.train_acc = torchmetrics.Accuracy()

        self.loss = glvq_loss

    @property
    def prototype_labels(self):
        return self.proto_layer.component_labels.detach().cpu()

    def forward(self, x):
        protos, _ = self.proto_layer()
        dis = self.hparams.distance(x, protos)
        return dis

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        x, y = train_batch
        x = x.view(x.size(0), -1)  # flatten
        dis = self(x)
        plabels = self.proto_layer.component_labels
        mu = self.loss(dis, y, prototype_labels=plabels)
        batch_loss = self.transfer_function(mu,
                                            beta=self.hparams.transfer_beta)
        loss = batch_loss.sum(dim=0)

        # Compute training accuracy
        with torch.no_grad():
            preds = wtac(dis, plabels)

        # `.int()` because FloatTensors are assumed to be class probabilities
        self.train_acc(preds.int(), y.int())

        # Logging
        self.log("train_loss", loss)
        self.log("acc",
                 self.train_acc,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

        return loss

    def predict(self, x):
        # model.eval() # ?!
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.component_labels
            y_pred = wtac(d, plabels)
        return y_pred.numpy()


class LVQ1(GLVQ):
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq1_loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
        scheduler = ExponentialLR(optimizer,
                                  gamma=0.99,
                                  last_epoch=-1,
                                  verbose=False)
        sch = {
            "scheduler": scheduler,
            "interval": "step",
        }  # called after each training step
        return [optimizer], [sch]


class LVQ21(GLVQ):
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq21_loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
        scheduler = ExponentialLR(optimizer,
                                  gamma=0.99,
                                  last_epoch=-1,
                                  verbose=False)
        sch = {
            "scheduler": scheduler,
            "interval": "step",
        }  # called after each training step
        return [optimizer], [sch]


class ImageGLVQ(GLVQ):
    """GLVQ for training on image data.

    GLVQ model that constrains the prototypes to the range [0, 1] by clamping
    after updates.
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        self.proto_layer.components.data.clamp_(0.0, 1.0)


class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. The weights in the
    transformation pipeline are only learned from the inputs.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params={},
                 sync=True,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        self.backbone = backbone_module(**backbone_params)
        self.backbone_dependent = backbone_module(
            **backbone_params).requires_grad_(False)
        self.sync = sync

    def sync_backbones(self):
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def configure_optimizers(self):
        optim = self.hparams.optimizer
        proto_opt = optim(self.proto_layer.parameters(),
                          lr=self.hparams.proto_lr)
        if list(self.backbone.parameters()):
            # Only add a second optimizer if the backbone has trainable
            # parameters; otherwise, the next line fails.
            bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
            return proto_opt, bb_opt
        else:
            return proto_opt

    def forward(self, x):
        if self.sync:
            self.sync_backbones()
        protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)
        dis = euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone.
        """
        # model.eval() # ?!
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = self.backbone_dependent(protos)
            d = euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.numpy()


class GRLVQ(GLVQ):
    """Generalized Relevance Learning Vector Quantization."""
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.relevances = torch.nn.parameter.Parameter(
            torch.ones(self.hparams.input_dim))

    def forward(self, x):
        protos, _ = self.proto_layer()
        dis = omega_distance(x, protos, torch.diag(self.relevances))
        return dis

    def backbone(self, x):
        return x @ torch.diag(self.relevances)

    @property
    def relevance_profile(self):
        return self.relevances.detach().cpu()

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone.
        """
        # model.eval() # ?!
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = protos @ torch.diag(self.relevances)
            d = squared_euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.numpy()


class GMLVQ(GLVQ):
    """Generalized Matrix Learning Vector Quantization."""
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
                                           self.hparams.latent_dim,
                                           bias=False)

        # Namespace hook for the visualization callbacks to work
        self.backbone = self.omega_layer

    @property
    def omega_matrix(self):
        return self.omega_layer.weight.detach().cpu()

    @property
    def lambda_matrix(self):
        omega = self.omega_layer.weight  # (latent_dim, input_dim)
        lam = omega.T @ omega
        return lam.detach().cpu()

    def show_lambda(self):
        import matplotlib.pyplot as plt
        title = "Lambda matrix"
        plt.figure(title)
        plt.title(title)
        plt.imshow(self.lambda_matrix, cmap="gray")
        plt.axis("off")
        plt.colorbar()
        plt.show(block=True)

    def forward(self, x):
        protos, _ = self.proto_layer()
        latent_x = self.omega_layer(x)
        latent_protos = self.omega_layer(protos)
        dis = squared_euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone.
        """
        # model.eval() # ?!
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = self.omega_layer(protos)
            d = squared_euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.numpy()


class LVQMLN(GLVQ):
    """Learning Vector Quantization Multi-Layer Network.

    GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
    on the prototypes before computing the distances between them. This, of
    course, means that the prototypes no longer live in the input space, but
    rather in the embedding space.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params={},
                 **kwargs):
        super().__init__(hparams, **kwargs)
        self.backbone = backbone_module(**backbone_params)
        with torch.no_grad():
            protos = self.backbone(self.proto_layer()[0])
        self.proto_layer.load_state_dict({"_components": protos}, strict=False)

    def forward(self, x):
        latent_protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        dis = euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space."""
        with torch.no_grad():
            latent_protos, plabels = self.proto_layer()
            d = euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.numpy()
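A sketch of training the GLVQ model above. The hyperparameter keys mirror the `hparams` accesses (nclasses, prototypes_per_class, prototype_initializer, lr), the prototype initializer is the one the NeuralGas defaults use later in this diff, and the data loader is a placeholder:

    import pytorch_lightning as pl
    from prototorch.components import initializers as cinit

    hparams = dict(
        nclasses=3,
        prototypes_per_class=2,
        prototype_initializer=cinit.ZerosInitializer(2),
        lr=0.01,
    )
    model = GLVQ(hparams)
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(model, train_loader)  # train_loader: DataLoader of (x, y)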
prototorch/models/library/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .glvq import GLVQ
from .gmlvq import GMLVQ

__all__ = [
    "GLVQ",
    "GMLVQ",
]
prototorch/models/library/glvq.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from dataclasses import dataclass

from prototorch.models import (
    SimpleComparisonMixin,
    SingleLearningRateMixin,
    SupervisedArchitecture,
    WTACompetitionMixin,
)
from prototorch.models.architectures.loss import GLVQLossMixin


class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """Generalized Learning Vector Quantization (GLVQ).

    A GLVQ architecture that uses the winner-take-all strategy and the GLVQ
    loss.
    """

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """No model-specific hyperparameters; all fields are inherited from
        the mixins.
        """
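Because every mixin contributes its own `HyperParameters` dataclass, constructing the model only requires filling the inherited fields. This sketch mirrors the test suite at the bottom of this diff:

    import prototorch as pt

    hparams = GLVQ.HyperParameters(
        distribution=dict(num_classes=2, per_class=1),
        component_initializer=pt.initializers.RNCI(2),
    )
    model = GLVQ(hparams=hparams)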
prototorch/models/library/gmlvq.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable

import torch
from prototorch.core.distances import omega_distance
from prototorch.models import (
    GLVQLossMixin,
    MultipleLearningRateMixin,
    OmegaComparisonMixin,
    SupervisedArchitecture,
    WTACompetitionMixin,
)


class GMLVQ(
        SupervisedArchitecture,
        OmegaComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        MultipleLearningRateMixin,
):
    """Generalized Matrix Learning Vector Quantization (GMLVQ).

    A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ
    loss.
    """

    # HyperParameters
    # ------------------------------------------------------------------------
    @dataclass
    class HyperParameters(
            MultipleLearningRateMixin.HyperParameters,
            OmegaComparisonMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """
        comparison_fn: The comparison / dissimilarity function to use.
            Overrides the default: omega_distance.
        comparison_args: Keyword arguments for the comparison function.
            Overrides the default: {}.
        """
        comparison_fn: Callable = omega_distance
        comparison_args: dict = field(default_factory=dict)
        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam

        lr: dict = field(default_factory=lambda: dict(
            components_layer=0.1,
            _omega=0.5,
        ))
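A sketch of overriding the per-module learning rates. The `lr` keys name the attributes used above (`components_layer`, `_omega`); `distribution` and `component_initializer` are inherited fields assumed to match the GLVQ test below, and `input_dim`/`latent_dim` are assumed `OmegaComparisonMixin` fields:

    import prototorch as pt

    hparams = GMLVQ.HyperParameters(
        distribution=dict(num_classes=2, per_class=1),
        component_initializer=pt.initializers.RNCI(2),
        input_dim=2,   # assumed OmegaComparisonMixin fields
        latent_dim=2,
        lr=dict(components_layer=0.05, _omega=0.1),
    )
    model = GMLVQ(hparams=hparams)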
@@ -1,69 +0,0 @@
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy

from .abstract import AbstractPrototypeModel


class EuclideanDistance(torch.nn.Module):
    def forward(self, x, y):
        return euclidean_distance(x, y)


class ConnectionTopology(torch.nn.Module):
    def __init__(self, agelimit, num_prototypes):
        super().__init__()
        self.agelimit = agelimit
        self.num_prototypes = num_prototypes

        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
        self.age = torch.zeros_like(self.cmat)

    def forward(self, d):
        order = torch.argsort(d, dim=1)

        for element in order:
            i0, i1 = element[0], element[1]
            # connect the two closest prototypes and reset the edge age
            self.cmat[i0][i1] = 1
            self.age[i0][i1] = 0
            # age every existing edge of the winner, then prune the ones
            # past the age limit
            self.age[i0][self.cmat[i0] == 1] += 1
            self.cmat[i0][self.age[i0] > self.agelimit] = 0

    def extra_repr(self):
        return f"agelimit: {self.agelimit}"


class NeuralGas(AbstractPrototypeModel):
    def __init__(self, hparams, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        # Default values
        self.hparams.setdefault("input_dim", 2)
        self.hparams.setdefault("agelimit", 10)
        self.hparams.setdefault("lm", 1)
        self.hparams.setdefault("prototype_initializer",
                                cinit.ZerosInitializer(self.hparams.input_dim))

        self.proto_layer = Components(
            self.hparams.num_prototypes,
            initializer=self.hparams.prototype_initializer)

        self.distance_layer = EuclideanDistance()
        self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
        self.topology_layer = ConnectionTopology(
            agelimit=self.hparams.agelimit,
            num_prototypes=self.hparams.num_prototypes,
        )

    def training_step(self, train_batch, batch_idx):
        x = train_batch[0]
        protos = self.proto_layer()
        d = self.distance_layer(x, protos)
        cost, order = self.energy_layer(d)

        self.topology_layer(d)
        return cost
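A training sketch for the NeuralGas model above; only `num_prototypes` is mandatory, the remaining keys fall back to the `setdefault` values, and the batches carry no labels (the training step only reads `train_batch[0]`):

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    x = torch.randn(500, 2)  # unlabeled 2D toy data
    train_loader = DataLoader(TensorDataset(x), batch_size=64)

    model = NeuralGas(dict(num_prototypes=30, lm=2, agelimit=10))
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(model, train_loader)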
@@ -1,295 +1,77 @@
"""Visualization Callbacks."""

import os
import warnings
from typing import Sized

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme, get_colors, get_legend_handles
from prototorch.utils.utils import (gif_from_dir, make_directory, mesh2d,
                                    prettify_string)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset


class VisWeights(pl.Callback):
    """Abstract weight visualization callback."""
    def __init__(
            self,
            data=None,
            ignore_last_output_row=False,
            label_map=None,
            project_mesh=False,
            project_protos=False,
            voronoi=False,
            axis_off=True,
            cmap="viridis",
            show=True,
            display_logs=True,
            display_logs_settings={},
            pause_time=0.5,
            border=1,
            resolution=10,
            interval=False,
            save=False,
            snap=True,
            save_dir="./img",
            make_gif=False,
            make_mp4=False,
            verbose=True,
            dpi=500,
            fps=5,
            figsize=(11, 8.5),  # standard paper in inches
            prefix="",
            distance_layer_index=-1,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.data = data
        self.ignore_last_output_row = ignore_last_output_row
        self.label_map = label_map
        self.voronoi = voronoi
        self.axis_off = True
        self.project_mesh = project_mesh
        self.project_protos = project_protos
        self.cmap = cmap
        self.show = show
        self.display_logs = display_logs
        self.display_logs_settings = display_logs_settings
        self.pause_time = pause_time
        self.border = border
        self.resolution = resolution
        self.interval = interval
        self.save = save
        self.snap = snap
        self.save_dir = save_dir
        self.make_gif = make_gif
        self.make_mp4 = make_mp4
        self.verbose = verbose
        self.dpi = dpi
        self.fps = fps
        self.figsize = figsize
        self.prefix = prefix
        self.distance_layer_index = distance_layer_index
        self.title = "Weights Visualization"
        make_directory(self.save_dir)

    def _skip_epoch(self, epoch):
        if self.interval:
            if epoch % self.interval != 0:
                return True
        return False

    def _clean_and_setup_ax(self):
        ax = self.ax
        if not self.snap:
            ax.cla()
        ax.set_title(self.title)
        if self.axis_off:
            ax.axis("off")

    def _savefig(self, fignum, orientation="horizontal"):
        figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
        figsize = self.figsize
        if orientation == "vertical":
            figsize = figsize[::-1]
        self.fig.set_size_inches(figsize, forward=False)
        self.fig.savefig(figname, dpi=self.dpi)

    def _show_and_save(self, epoch):
        if self.show:
            plt.pause(self.pause_time)
        if self.save:
            self._savefig(epoch)
        if self.snap:
            self.camera.snap()

    def _display_logs(self, ax, epoch, logs):
        if self.display_logs:
            settings = dict(
                loc="lower right",
                # padding between the text and bounding box
                pad=0.5,
                # padding between the bounding box and the axes
                borderpad=1.0,
                # https://matplotlib.org/api/text_api.html#matplotlib.text.Text
                prop=dict(
                    fontfamily="monospace",
                    fontweight="medium",
                    fontsize=12,
                ),
            )

            # Override settings with self.display_logs_settings.
            settings = {**settings, **self.display_logs_settings}

            log_string = f"""Epoch: {epoch:04d},
            val_loss: {logs.get('val_loss', np.nan):.03f},
            val_acc: {logs.get('val_acc', np.nan):.03f},
            loss: {logs.get('loss', np.nan):.03f},
            acc: {logs.get('acc', np.nan):.03f}
            """
            log_string = prettify_string(log_string, end="")
            # https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
            anchored_text = AnchoredText(log_string, **settings)
            self.ax.add_artist(anchored_text)

    def on_train_start(self, trainer, pl_module, logs={}):
        self.fig = plt.figure(self.title)
        self.fig.set_size_inches(self.figsize, forward=False)
        self.ax = self.fig.add_subplot(111)
        self.camera = Camera(self.fig)

    def on_train_end(self, trainer, pl_module, logs={}):
        if self.make_gif:
            gif_from_dir(directory=self.save_dir,
                         prefix=self.prefix,
                         duration=1.0 / self.fps)
        if self.snap and self.make_mp4:
            animation = self.camera.animate()
            vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
            if self.verbose:
                print(f"Saving mp4 under {vid}.")
            animation.save(vid, fps=self.fps, dpi=self.dpi)


class VisPointProtos(VisWeights):
    """Visualization of prototypes.
    .. TODO::
        Still in progress.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.title = "Point Prototypes Visualization"
        self.data_scatter_settings = {
            "marker": "o",
            "s": 30,
            "edgecolor": "k",
            "cmap": self.cmap,
        }
        self.protos_scatter_settings = {
            "marker": "D",
            "s": 50,
            "edgecolor": "k",
            "cmap": self.cmap,
        }

    def on_epoch_start(self, trainer, pl_module, logs={}):
        epoch = trainer.current_epoch
        if self._skip_epoch(epoch):
            return True

        self._clean_and_setup_ax()

        protos = pl_module.prototypes
        labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()

        if self.project_protos:
            protos = self.model.projection(protos).numpy()

        color_map = color_scheme(n=len(set(labels)),
                                 cmap=self.cmap,
                                 zero_indexed=True)
        # TODO Get rid of the assumption that y values are in [0, num_classes]
        label_colors = [color_map[l] for l in labels]

        if self.data is not None:
            x, y = self.data
            # TODO Get rid of the assumption that y values are in [0, num_classes]
            y_colors = [color_map[l] for l in y]
            # x = self.model.projection(x)
            if not isinstance(x, np.ndarray):
                x = x.numpy()

            # Plot data points.
            self.ax.scatter(x[:, 0],
                            x[:, 1],
                            c=y_colors,
                            **self.data_scatter_settings)

            # Paint decision regions.
            if self.voronoi:
                border = self.border
                resolution = self.resolution
                x = np.vstack((x, protos))
                x_min, x_max = x[:, 0].min(), x[:, 0].max()
                y_min, y_max = x[:, 1].min(), x[:, 1].max()
                x_min, x_max = x_min - border, x_max + border
                y_min, y_max = y_min - border, y_max + border
                try:
                    xx, yy = np.meshgrid(
                        np.arange(x_min, x_max, (x_max - x_min) / resolution),
                        np.arange(y_min, y_max, (x_max - x_min) / resolution),
                    )
                except ValueError as ve:
                    print(ve)
                    raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
                                     f"x_max - x_min is {x_max - x_min}.")
                except MemoryError as me:
                    print(me)
                    raise ValueError("Too many points. "
                                     "Try reducing the resolution.")
                mesh_input = np.c_[xx.ravel(), yy.ravel()]

                # Predict mesh labels.
                if self.project_mesh:
                    mesh_input = self.model.projection(mesh_input)

                y_pred = pl_module.predict(torch.Tensor(mesh_input))
                y_pred = y_pred.reshape(xx.shape)

                # Plot voronoi regions.
                self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)

                self.ax.set_xlim(left=x_min, right=x_max)
                self.ax.set_ylim(bottom=y_min, top=y_max)

        # Plot prototypes.
        self.ax.scatter(protos[:, 0],
                        protos[:, 1],
                        c=label_colors,
                        **self.protos_scatter_settings)

        # self._show_and_save(epoch)

    def on_epoch_end(self, trainer, pl_module, logs={}):
        epoch = trainer.current_epoch
        self._display_logs(self.ax, epoch, logs)
        self._show_and_save(epoch)


class Vis2DAbstract(pl.Callback):

    def __init__(self,
                 data=None,
                 title="Prototype Visualization",
                 cmap="viridis",
                 xlabel="Data dimension 1",
                 ylabel="Data dimension 2",
                 legend_labels=None,
                 border=0.1,
                 resolution=100,
                 flatten_data=True,
                 axis_off=False,
                 show_protos=True,
                 show=True,
                 tensorboard=False,
                 show_last_only=False,
                 pause_time=0.1,
                 block=False):
        super().__init__()

        if data:
            if isinstance(data, Dataset):
                if isinstance(data, Sized):
                    x, y = next(iter(DataLoader(data, batch_size=len(data))))
                else:
                    # TODO: Add support for non-sized datasets
                    raise NotImplementedError(
                        "Data must be a dataset with a __len__ method.")
            elif isinstance(data, DataLoader):
                x = torch.tensor([])
                y = torch.tensor([])
                for x_b, y_b in data:
                    x = torch.cat([x, x_b])
                    y = torch.cat([y, y_b])
            else:
                x, y = data

            if flatten_data:
                x = x.reshape(len(x), -1)

            self.x_train = x
            self.y_train = y
        else:
            self.x_train = None
            self.y_train = None

        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend_labels = legend_labels
        self.fig = plt.figure(self.title)
        self.cmap = cmap
        self.border = border
        self.resolution = resolution
        self.axis_off = axis_off
        self.show_protos = show_protos
        self.show = show
        self.tensorboard = tensorboard
        self.show_last_only = show_last_only
        self.pause_time = pause_time
        self.block = block
@@ -298,27 +80,19 @@ class Vis2DAbstract(pl.Callback):
    def precheck(self, trainer):
        if self.show_last_only:
            if trainer.current_epoch != trainer.max_epochs - 1:
                return False
        return True

    def setup_ax(self):
        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        if self.axis_off:
            ax.axis("off")
        return ax

    def get_mesh_input(self, x):
        x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
        y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
                             np.arange(y_min, y_max, 1 / self.resolution))
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
        return mesh_input, xx, yy

    def plot_data(self, ax, x, y):
        ax.scatter(
            x[:, 0],
@@ -351,94 +125,135 @@ class Vis2DAbstract(pl.Callback):
    def log_and_display(self, trainer, pl_module):
        if self.tensorboard:
            self.add_to_tensorboard(trainer, pl_module)
        if self.show:
            if not self.block:
                plt.pause(self.pause_time)
            else:
                plt.show(block=self.block)

    def on_train_epoch_end(self, trainer, pl_module):
        if not self.precheck(trainer):
            return True
        self.visualize(pl_module)
        self.log_and_display(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        plt.close()

    def visualize(self, pl_module):
        raise NotImplementedError


class VisGLVQ2D(Vis2DAbstract):

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        ax = self.setup_ax()
        self.plot_protos(ax, protos, plabels)
        if x_train is not None:
            self.plot_data(ax, x_train, y_train)
            mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
                                        self.border, self.resolution)
        else:
            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
        _components = pl_module.proto_layer._components
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
        y_pred = pl_module.predict(mesh_input)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)


class VisSiameseGLVQ2D(Vis2DAbstract):

    def __init__(self, *args, map_protos=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.map_protos = map_protos

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        device = pl_module.device
        with torch.no_grad():
            x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
            x_train = x_train.cpu().detach()
        if self.map_protos:
            with torch.no_grad():
                protos = pl_module.backbone(torch.Tensor(protos).to(device))
                protos = protos.cpu().detach()
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)
            x = np.vstack((x_train, protos))
            mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
        else:
            mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
        _components = pl_module.proto_layer._components
        mesh_input = torch.Tensor(mesh_input).type_as(_components)
        y_pred = pl_module.predict_latent(mesh_input,
                                          map_protos=self.map_protos)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)


class VisGMLVQ2D(Vis2DAbstract):

    def __init__(self, *args, ev_proj=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.ev_proj = ev_proj

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        device = pl_module.device
        omega = pl_module._omega.detach()
        lam = omega @ omega.T
        u, _, _ = torch.pca_lowrank(lam, q=2)
        with torch.no_grad():
            x_train = torch.Tensor(x_train).to(device)
            x_train = x_train @ u
            x_train = x_train.cpu().detach()
        if self.show_protos:
            with torch.no_grad():
                protos = torch.Tensor(protos).to(device)
                protos = protos @ u
                protos = protos.cpu().detach()
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)


class VisCBC2D(Vis2DAbstract):

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.components
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        self.plot_protos(ax, protos, "w")
        x = np.vstack((x_train, protos))
        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
        _components = pl_module.components_layer._components
        y_pred = pl_module.predict(
            torch.Tensor(mesh_input).type_as(_components))
        y_pred = y_pred.cpu().reshape(xx.shape)

        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)


class VisNG2D(Vis2DAbstract):

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.prototypes
        cmat = pl_module.topology_layer.cmat.cpu().numpy()

        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        self.plot_protos(ax, protos, "w")
@@ -452,4 +267,97 @@ class VisNG2D(Vis2DAbstract):
                "k-",
            )


class VisSpectralProtos(Vis2DAbstract):

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        ax = self.setup_ax()
        colors = get_colors(vmax=max(plabels), vmin=min(plabels))
        for p, pl in zip(protos, plabels):
            ax.plot(p, c=colors[int(pl)])
        if self.legend_labels:
            handles = get_legend_handles(
                colors,
                self.legend_labels,
                marker="lines",
            )
            ax.legend(handles=handles)


class VisImgComp(Vis2DAbstract):

    def __init__(self,
                 *args,
                 random_data=0,
                 dataformats="CHW",
                 num_columns=2,
                 add_embedding=False,
                 embedding_data=100,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.random_data = random_data
        self.dataformats = dataformats
        self.num_columns = num_columns
        self.add_embedding = add_embedding
        self.embedding_data = embedding_data

    def on_train_start(self, _, pl_module):
        if isinstance(pl_module.logger, TensorBoardLogger):
            tb = pl_module.logger.experiment

            # Add embedding
            if self.add_embedding:
                if self.x_train is not None and self.y_train is not None:
                    ind = np.random.choice(len(self.x_train),
                                           size=self.embedding_data,
                                           replace=False)
                    data = self.x_train[ind]
                    tb.add_embedding(data.view(len(ind), -1),
                                     label_img=data,
                                     global_step=None,
                                     tag="Data Embedding",
                                     metadata=self.y_train[ind],
                                     metadata_header=None)
                else:
                    raise ValueError("No data available for the "
                                     "add_embedding flag")

            # Random data
            if self.random_data:
                if self.x_train is not None:
                    ind = np.random.choice(len(self.x_train),
                                           size=self.random_data,
                                           replace=False)
                    data = self.x_train[ind]
                    grid = torchvision.utils.make_grid(data,
                                                       nrow=self.num_columns)
                    tb.add_image(tag="Data",
                                 img_tensor=grid,
                                 global_step=None,
                                 dataformats=self.dataformats)
                else:
                    raise ValueError("No data available for the "
                                     "random_data flag")

        else:
            warnings.warn(
                f"TensorBoardLogger is required, got {type(pl_module.logger)}")

    def add_to_tensorboard(self, trainer, pl_module):
        tb = pl_module.logger.experiment

        components = pl_module.components
        grid = torchvision.utils.make_grid(components, nrow=self.num_columns)
        tb.add_image(
            tag="Components",
            img_tensor=grid,
            global_step=trainer.current_epoch,
            dataformats=self.dataformats,
        )

    def visualize(self, pl_module):
        if self.show:
            components = pl_module.components
            grid = torchvision.utils.make_grid(components,
                                               nrow=self.num_columns)
            plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
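A sketch of attaching one of the callbacks above to a run; `train_ds` stands in for any sized dataset of 2D points, which the callback loads once at construction, and `model`/`train_loader` are placeholders:

    import pytorch_lightning as pl

    vis = VisGLVQ2D(data=train_ds, resolution=100, show=True)
    trainer = pl.Trainer(max_epochs=50, callbacks=[vis])
    trainer.fit(model, train_loader)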
setup.cfg (new file, 23 lines)
@@ -0,0 +1,23 @@
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true

[pylint]
disable =
    too-many-arguments,
    too-few-public-methods,
    fixme,

[pycodestyle]
max-line-length = 79

[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79
setup.py (65 lines)
@@ -1,13 +1,17 @@
"""
[ASCII-art "ProtoTorch" plugin banner]

ProtoTorch models Plugin Package
"""
from pathlib import Path

from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
@@ -16,18 +20,43 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"

long_description = Path("README.md").read_text(encoding='utf8')

INSTALL_REQUIRES = [
    "prototorch>=0.7.3",
    "pytorch_lightning>=1.6.0",
    "torchmetrics",
    "protobuf<3.20.0",
]
CLI = [
    "jsonargparse",
]
DEV = [
    "bumpversion",
    "pre-commit",
]
DOCS = [
    "recommonmark",
    "sphinx",
    "nbsphinx",
    "ipykernel",
    "sphinx_rtd_theme",
    "sphinxcontrib-katex",
    "sphinxcontrib-bibtex",
]
EXAMPLES = [
    "matplotlib",
    "scikit-learn",
]
TESTS = [
    "codecov",
    "pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS

setup(
    name=safe_name("prototorch_" + PLUGIN_NAME),
    version="1.0.0-a8",
    description="Pre-packaged prototype-based "
    "machine learning models using ProtoTorch and PyTorch-Lightning.",
    long_description=long_description,
@@ -37,6 +66,7 @@ setup(
    url=PROJECT_URL,
    download_url=DOWNLOAD_URL,
    license="MIT",
    python_requires=">=3.7",
    install_requires=INSTALL_REQUIRES,
    extras_require={
        "dev": DEV,
@@ -52,10 +82,11 @@ setup(
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.7",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries",
@@ -1,6 +0,0 @@
import unittest


class TestDummy(unittest.TestCase):
    def test_one(self):
        self.assertEqual(True, True)
tests/test_examples.sh (new executable file, 35 lines)
@@ -0,0 +1,35 @@
#! /bin/bash

# Read flags
gpu=0
while [ -n "$1" ]; do
    case "$1" in
        --gpu) gpu=1;;
        -g) gpu=1;;
        *) path=$1;;
    esac
    shift
done

python --version
echo "Using GPU: " $gpu

# Loop over the example scripts
failed=0

for example in $(find $path -maxdepth 1 -name "*.py")
do
    echo -n $example '... '
    export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
    if [[ $? -ne 0 ]]; then
        echo "FAILED!!"
        cat run_log.txt
        failed=1
    else
        echo "SUCCESS!"
    fi
    rm run_log.txt
done

exit $failed
tests/test_models.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""prototorch.models test suite."""

import prototorch as pt
from prototorch.models.library import GLVQ


def test_glvq_model_build():
    hparams = GLVQ.HyperParameters(
        distribution=dict(num_classes=2, per_class=1),
        component_initializer=pt.initializers.RNCI(2),
    )

    model = GLVQ(hparams=hparams)