Compare commits
50 Commits
refactor/s
...
feature/je
Author | SHA1 | Date | |
---|---|---|---|
|
0788718c31 | ||
|
4f5c4ebe8f | ||
|
ae2a8e54ef | ||
|
d9be100c1f | ||
|
9d1dc7320f | ||
|
d11ab71b7e | ||
|
59037e1a50 | ||
|
a19b99be82 | ||
|
f7e7558338 | ||
|
d57648f9d6 | ||
|
d24f580bf0 | ||
|
916973c3e8 | ||
|
b49b7a2d41 | ||
|
b6e8242383 | ||
|
cd616d11b9 | ||
|
afcfcb8973 | ||
|
bf03a45475 | ||
|
083b5c1597 | ||
|
7f0a8e9bce | ||
|
bf09ff8f7f | ||
|
c1d7cfee8f | ||
|
99be965581 | ||
|
fdb9a7c66d | ||
|
eb79b703d8 | ||
|
bc9a826b7d | ||
|
cfe09ec06b | ||
|
3d76dffe3c | ||
|
597c9fc1ee | ||
|
a8c74a1a6f | ||
|
f78ff1a464 | ||
|
5a3dbfac2e | ||
|
478a3c2cfe | ||
|
4520fdde8e | ||
|
b90044b86c | ||
|
a1310df4ee | ||
|
5dc66494ea | ||
|
74d420a77d | ||
|
6ffd14e85c | ||
|
40c1021c20 | ||
|
acf3272fd7 | ||
|
c73f8e7a28 | ||
|
bf23d5f7f8 | ||
|
bcde3f6ac8 | ||
|
d5229b1750 | ||
|
fc4b143fbb | ||
|
11cfa79746 | ||
|
d0ae94f2af | ||
|
2c908a8361 | ||
|
e4257ec1f1 | ||
|
aaad2b8626 |
@@ -1,10 +1,10 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.5.0
|
current_version = 0.7.1
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize =
|
serialize = {major}.{minor}.{patch}
|
||||||
{major}.{minor}.{patch}
|
message = build: bump version {current_version} → {new_version}
|
||||||
|
|
||||||
[bumpversion:file:setup.py]
|
[bumpversion:file:setup.py]
|
||||||
|
|
||||||
|
7
.ci/python310.Dockerfile
Normal file
7
.ci/python310.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
FROM python:3.9
|
||||||
|
|
||||||
|
RUN adduser --uid 1000 jenkins
|
||||||
|
|
||||||
|
USER jenkins
|
||||||
|
|
||||||
|
RUN mkdir -p /home/jenkins/.cache/pip
|
7
.ci/python36.Dockerfile
Normal file
7
.ci/python36.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
FROM python:3.6
|
||||||
|
|
||||||
|
RUN adduser --uid 1000 jenkins
|
||||||
|
|
||||||
|
USER jenkins
|
||||||
|
|
||||||
|
RUN mkdir -p /home/jenkins/.cache/pip
|
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -10,21 +10,28 @@ assignees: ''
|
|||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
**To Reproduce**
|
**Steps to reproduce the behavior**
|
||||||
Steps to reproduce the behavior:
|
1. ...
|
||||||
1. Install Prototorch by running '...'
|
2. Run script '...' or this snippet:
|
||||||
2. Run script '...'
|
```python
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
3. See errors
|
3. See errors
|
||||||
|
|
||||||
**Expected behavior**
|
**Expected behavior**
|
||||||
A clear and concise description of what you expected to happen.
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Observed behavior**
|
||||||
|
A clear and concise description of what actually happened.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
|
|
||||||
**Desktop (please complete the following information):**
|
**System and version information**
|
||||||
- OS: [e.g. Ubuntu 20.10]
|
- OS: [e.g. Ubuntu 20.10]
|
||||||
- Prototorch Version: [e.g. v0.4.0]
|
- ProtoTorch Version: [e.g. 0.4.0]
|
||||||
- Python Version: [e.g. 3.9.5]
|
- Python Version: [e.g. 3.9.5]
|
||||||
|
|
||||||
**Additional context**
|
**Additional context**
|
||||||
|
4
.github/workflows/pythonapp.yml
vendored
4
.github/workflows/pythonapp.yml
vendored
@@ -16,10 +16,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python 3.8
|
- name: Set up Python 3.9
|
||||||
uses: actions/setup-python@v1
|
uses: actions/setup-python@v1
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: 3.9
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@@ -155,4 +155,4 @@ dmypy.json
|
|||||||
reports
|
reports
|
||||||
artifacts
|
artifacts
|
||||||
examples/_*.py
|
examples/_*.py
|
||||||
examples/_*.ipynb
|
examples/_*.ipynb
|
||||||
|
53
.pre-commit-config.yaml
Normal file
53
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
# See https://pre-commit.com for more information
|
||||||
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.0.1
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-ast
|
||||||
|
- id: check-case-conflict
|
||||||
|
|
||||||
|
- repo: https://github.com/myint/autoflake
|
||||||
|
rev: v1.4
|
||||||
|
hooks:
|
||||||
|
- id: autoflake
|
||||||
|
|
||||||
|
- repo: http://github.com/PyCQA/isort
|
||||||
|
rev: 5.8.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
rev: v0.902
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
files: prototorch
|
||||||
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
|
rev: v0.31.0
|
||||||
|
hooks:
|
||||||
|
- id: yapf
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
|
rev: v1.9.0
|
||||||
|
hooks:
|
||||||
|
- id: python-use-type-annotations
|
||||||
|
- id: python-no-log-warn
|
||||||
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v2.19.4
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
|
||||||
|
- repo: https://github.com/si-cim/gitlint
|
||||||
|
rev: v0.15.2-unofficial
|
||||||
|
hooks:
|
||||||
|
- id: gitlint
|
||||||
|
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
@@ -19,7 +19,7 @@ formats: all
|
|||||||
|
|
||||||
# Optionally set the version of Python and requirements required to build your docs
|
# Optionally set the version of Python and requirements required to build your docs
|
||||||
python:
|
python:
|
||||||
version: 3.8
|
version: 3.9
|
||||||
install:
|
install:
|
||||||
- method: pip
|
- method: pip
|
||||||
path: .
|
path: .
|
||||||
|
36
.travis.yml
36
.travis.yml
@@ -1,36 +0,0 @@
|
|||||||
dist: bionic
|
|
||||||
sudo: false
|
|
||||||
language: python
|
|
||||||
python: 3.9
|
|
||||||
cache:
|
|
||||||
directories:
|
|
||||||
- "$HOME/.cache/pip"
|
|
||||||
- "./tests/artifacts"
|
|
||||||
- "$HOME/datasets"
|
|
||||||
install:
|
|
||||||
- pip install .[all] --progress-bar off
|
|
||||||
|
|
||||||
# Generate code coverage report
|
|
||||||
script:
|
|
||||||
- coverage run -m pytest
|
|
||||||
|
|
||||||
# Push the results to codecov
|
|
||||||
after_success:
|
|
||||||
- bash <(curl -s https://codecov.io/bash)
|
|
||||||
|
|
||||||
# Publish on PyPI
|
|
||||||
deploy:
|
|
||||||
provider: pypi
|
|
||||||
username: __token__
|
|
||||||
password:
|
|
||||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
|
||||||
on:
|
|
||||||
tags: true
|
|
||||||
skip_existing: true
|
|
||||||
|
|
||||||
# The password is encrypted with:
|
|
||||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
|
||||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
|
||||||
# https://github.com/travis-ci/travis.rb#installation
|
|
||||||
# for more details
|
|
||||||
# Note: The encrypt command does not work well in ZSH.
|
|
41
Jenkinsfile
vendored
Normal file
41
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
pipeline {
|
||||||
|
agent none
|
||||||
|
stages {
|
||||||
|
stage('Unit Tests') {
|
||||||
|
parallel {
|
||||||
|
stage('3.6'){
|
||||||
|
agent{
|
||||||
|
dockerfile {
|
||||||
|
filename 'python36.Dockerfile'
|
||||||
|
dir '.ci'
|
||||||
|
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
steps {
|
||||||
|
sh 'pip install pip --upgrade --progress-bar off'
|
||||||
|
sh 'pip install .[all] --progress-bar off'
|
||||||
|
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||||
|
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||||
|
junit 'reports/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stage('3.10'){
|
||||||
|
agent{
|
||||||
|
dockerfile {
|
||||||
|
filename 'python310.Dockerfile'
|
||||||
|
dir '.ci'
|
||||||
|
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
steps {
|
||||||
|
sh 'pip install pip --upgrade --progress-bar off'
|
||||||
|
sh 'pip install .[all] --progress-bar off'
|
||||||
|
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||||
|
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||||
|
junit 'reports/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
20
README.md
20
README.md
@@ -2,13 +2,12 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
[](https://travis-ci.org/si-cim/prototorch)
|
[](https://travis-ci.com/github/si-cim/prototorch)
|
||||||

|

|
||||||
[](https://github.com/si-cim/prototorch/releases)
|
[](https://github.com/si-cim/prototorch/releases)
|
||||||
[](https://pypi.org/project/prototorch/)
|
[](https://pypi.org/project/prototorch/)
|
||||||
[](https://codecov.io/gh/si-cim/prototorch)
|
[](https://codecov.io/gh/si-cim/prototorch)
|
||||||
[](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
|
[](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
|
||||||

|
|
||||||
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
||||||
|
|
||||||
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
||||||
@@ -48,6 +47,23 @@ pip install -e .[all]
|
|||||||
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
|
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
|
||||||
that link not work try <https://prototorch.readthedocs.io/en/latest/>.
|
that link not work try <https://prototorch.readthedocs.io/en/latest/>.
|
||||||
|
|
||||||
|
## Contribution
|
||||||
|
|
||||||
|
This repository contains definition for [git hooks](https://githooks.com).
|
||||||
|
[Pre-commit](https://pre-commit.com) is automatically installed as development
|
||||||
|
dependency with prototorch or you can install it manually with `pip install
|
||||||
|
pre-commit`.
|
||||||
|
|
||||||
|
Please install the hooks by running:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
pre-commit install --hook-type commit-msg
|
||||||
|
```
|
||||||
|
before creating the first commit.
|
||||||
|
|
||||||
|
The commit will fail if the commit message does not follow the specification
|
||||||
|
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
|
||||||
|
|
||||||
## Bibtex
|
## Bibtex
|
||||||
|
|
||||||
If you would like to cite the package, please use this:
|
If you would like to cite the package, please use this:
|
||||||
|
46
deprecated.travis.yml
Normal file
46
deprecated.travis.yml
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
dist: bionic
|
||||||
|
sudo: false
|
||||||
|
language: python
|
||||||
|
python:
|
||||||
|
- 3.9
|
||||||
|
- 3.8
|
||||||
|
- 3.7
|
||||||
|
- 3.6
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- "$HOME/.cache/pip"
|
||||||
|
- "./tests/artifacts"
|
||||||
|
- "$HOME/datasets"
|
||||||
|
install:
|
||||||
|
- pip install .[all] --progress-bar off
|
||||||
|
|
||||||
|
# Generate code coverage report
|
||||||
|
script:
|
||||||
|
- coverage run -m pytest
|
||||||
|
|
||||||
|
# Push the results to codecov
|
||||||
|
after_success:
|
||||||
|
- bash <(curl -s https://codecov.io/bash)
|
||||||
|
|
||||||
|
# Publish on PyPI
|
||||||
|
jobs:
|
||||||
|
include:
|
||||||
|
- stage: build
|
||||||
|
python: 3.9
|
||||||
|
script: echo "Starting Pypi build"
|
||||||
|
deploy:
|
||||||
|
provider: pypi
|
||||||
|
username: __token__
|
||||||
|
distributions: "sdist bdist_wheel"
|
||||||
|
password:
|
||||||
|
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||||
|
on:
|
||||||
|
tags: true
|
||||||
|
skip_existing: true
|
||||||
|
|
||||||
|
# The password is encrypted with:
|
||||||
|
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||||
|
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||||
|
# https://github.com/travis-ci/travis.rb#installation
|
||||||
|
# for more details
|
||||||
|
# Note: The encrypt command does not work well in ZSH.
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.5.0"
|
release = "0.7.1"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
96
examples/cbc_iris.py
Normal file
96
examples/cbc_iris.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""ProtoTorch CBC example using 2D Iris data."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
|
class CBC(torch.nn.Module):
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.components_layer = pt.components.ReasoningComponents(
|
||||||
|
distribution=[2, 1, 2],
|
||||||
|
components_initializer=pt.initializers.SSCI(data, noise=0.1),
|
||||||
|
reasonings_initializer=pt.initializers.PPRI(components_first=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
components, reasonings = self.components_layer()
|
||||||
|
sims = pt.similarities.euclidean_similarity(x, components)
|
||||||
|
probs = pt.competitions.cbcc(sims, reasonings)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class VisCBC2D():
|
||||||
|
def __init__(self, model, data):
|
||||||
|
self.model = model
|
||||||
|
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
|
||||||
|
self.title = "Components Visualization"
|
||||||
|
self.fig = plt.figure(self.title)
|
||||||
|
self.border = 0.1
|
||||||
|
self.resolution = 100
|
||||||
|
self.cmap = "viridis"
|
||||||
|
|
||||||
|
def on_epoch_end(self):
|
||||||
|
x_train, y_train = self.x_train, self.y_train
|
||||||
|
_components = self.model.components_layer._components.detach()
|
||||||
|
ax = self.fig.gca()
|
||||||
|
ax.cla()
|
||||||
|
ax.set_title(self.title)
|
||||||
|
ax.axis("off")
|
||||||
|
ax.scatter(
|
||||||
|
x_train[:, 0],
|
||||||
|
x_train[:, 1],
|
||||||
|
c=y_train,
|
||||||
|
cmap=self.cmap,
|
||||||
|
edgecolor="k",
|
||||||
|
marker="o",
|
||||||
|
s=30,
|
||||||
|
)
|
||||||
|
ax.scatter(
|
||||||
|
_components[:, 0],
|
||||||
|
_components[:, 1],
|
||||||
|
c="w",
|
||||||
|
cmap=self.cmap,
|
||||||
|
edgecolor="k",
|
||||||
|
marker="D",
|
||||||
|
s=50,
|
||||||
|
)
|
||||||
|
x = torch.vstack((x_train, _components))
|
||||||
|
mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
|
||||||
|
with torch.no_grad():
|
||||||
|
y_pred = self.model(
|
||||||
|
torch.Tensor(mesh_input).type_as(_components)).argmax(1)
|
||||||
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
plt.pause(0.2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||||
|
|
||||||
|
model = CBC(train_ds)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||||||
|
criterion = pt.losses.MarginLoss(margin=0.1)
|
||||||
|
vis = VisCBC2D(model, train_ds)
|
||||||
|
|
||||||
|
for epoch in range(200):
|
||||||
|
correct = 0.0
|
||||||
|
for x, y in train_loader:
|
||||||
|
y_oh = torch.eye(3)[y]
|
||||||
|
y_pred = model(x)
|
||||||
|
loss = criterion(y_pred, y_oh).mean(0)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||||
|
|
||||||
|
acc = 100 * correct / len(train_ds)
|
||||||
|
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||||
|
vis.on_epoch_end()
|
@@ -1,120 +0,0 @@
|
|||||||
"""ProtoTorch GLVQ example using 2D Iris data."""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
|
||||||
from prototorch.functions.competitions import wtac
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from torchinfo import summary
|
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
|
||||||
scaler = StandardScaler()
|
|
||||||
x_train, y_train = load_iris(return_X_y=True)
|
|
||||||
x_train = x_train[:, [0, 2]]
|
|
||||||
scaler.fit(x_train)
|
|
||||||
x_train = scaler.transform(x_train)
|
|
||||||
|
|
||||||
|
|
||||||
# Define the GLVQ model
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
"""GLVQ model for training on 2D Iris data."""
|
|
||||||
super().__init__()
|
|
||||||
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
|
||||||
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
|
|
||||||
self.proto_layer = LabeledComponents(
|
|
||||||
prototype_distribution,
|
|
||||||
prototype_initializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
prototypes, prototype_labels = self.proto_layer()
|
|
||||||
distances = euclidean_distance(x, prototypes)
|
|
||||||
return distances, prototype_labels
|
|
||||||
|
|
||||||
|
|
||||||
# Build the GLVQ model
|
|
||||||
model = Model()
|
|
||||||
|
|
||||||
# Print summary using torchinfo (might be buggy/incorrect)
|
|
||||||
print(summary(model))
|
|
||||||
|
|
||||||
# Optimize using SGD optimizer from `torch.optim`
|
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
|
||||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
|
||||||
|
|
||||||
x_in = torch.Tensor(x_train)
|
|
||||||
y_in = torch.Tensor(y_train)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
TITLE = "Prototype Visualization"
|
|
||||||
fig = plt.figure(TITLE)
|
|
||||||
for epoch in range(70):
|
|
||||||
# Compute loss
|
|
||||||
distances, prototype_labels = model(x_in)
|
|
||||||
loss = criterion([distances, prototype_labels], y_in)
|
|
||||||
|
|
||||||
# Compute Accuracy
|
|
||||||
with torch.no_grad():
|
|
||||||
predictions = wtac(distances, prototype_labels)
|
|
||||||
correct = predictions.eq(y_in.view_as(predictions)).sum().item()
|
|
||||||
acc = 100.0 * correct / len(x_train)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Optimizer step
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# Get the prototypes form the model
|
|
||||||
prototypes = model.proto_layer.components.numpy()
|
|
||||||
if np.isnan(np.sum(prototypes)):
|
|
||||||
print("Stopping training because of `nan` in prototypes.")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
|
||||||
ax = fig.gca()
|
|
||||||
ax.cla()
|
|
||||||
ax.set_title(TITLE)
|
|
||||||
ax.set_xlabel("Data dimension 1")
|
|
||||||
ax.set_ylabel("Data dimension 2")
|
|
||||||
cmap = "viridis"
|
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
|
||||||
ax.scatter(
|
|
||||||
prototypes[:, 0],
|
|
||||||
prototypes[:, 1],
|
|
||||||
c=prototype_labels,
|
|
||||||
cmap=cmap,
|
|
||||||
edgecolor="k",
|
|
||||||
marker="D",
|
|
||||||
s=50,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Paint decision regions
|
|
||||||
x = np.vstack((x_train, prototypes))
|
|
||||||
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
|
||||||
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
|
||||||
np.arange(y_min, y_max, 1 / 50))
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
||||||
|
|
||||||
torch_input = torch.Tensor(mesh_input)
|
|
||||||
d = model(torch_input)[0]
|
|
||||||
w_indices = torch.argmin(d, dim=1)
|
|
||||||
y_pred = torch.index_select(prototype_labels, 0, w_indices)
|
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
|
||||||
|
|
||||||
# Plot voronoi regions
|
|
||||||
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
|
|
||||||
|
|
||||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
|
||||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
|
||||||
|
|
||||||
plt.pause(0.1)
|
|
@@ -1,103 +0,0 @@
|
|||||||
"""ProtoTorch "siamese" GMLVQ example using Tecator."""
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import torch
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
|
||||||
from prototorch.datasets.tecator import Tecator
|
|
||||||
from prototorch.functions.distances import sed
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from prototorch.utils.colors import get_legend_handles
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# Prepare the dataset and dataloader
|
|
||||||
train_data = Tecator(root="./artifacts", train=True)
|
|
||||||
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""GMLVQ model as a siamese network."""
|
|
||||||
super().__init__()
|
|
||||||
prototype_initializer = StratifiedMeanInitializer(train_loader)
|
|
||||||
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
|
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
|
||||||
prototype_distribution,
|
|
||||||
prototype_initializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.omega = torch.nn.Linear(in_features=100,
|
|
||||||
out_features=100,
|
|
||||||
bias=False)
|
|
||||||
torch.nn.init.eye_(self.omega.weight)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
protos = self.proto_layer.components
|
|
||||||
plabels = self.proto_layer.component_labels
|
|
||||||
|
|
||||||
# Process `x` and `protos` through `omega`
|
|
||||||
x_map = self.omega(x)
|
|
||||||
protos_map = self.omega(protos)
|
|
||||||
|
|
||||||
# Compute distances and output
|
|
||||||
dis = sed(x_map, protos_map)
|
|
||||||
return dis, plabels
|
|
||||||
|
|
||||||
|
|
||||||
# Build the GLVQ model
|
|
||||||
model = Model()
|
|
||||||
|
|
||||||
# Print a summary of the model
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
# Optimize using Adam optimizer from `torch.optim`
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
|
|
||||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
|
|
||||||
criterion = GLVQLoss(squashing="identity", beta=10)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
for epoch in range(150):
|
|
||||||
epoch_loss = 0.0 # zero-out epoch loss
|
|
||||||
optimizer.zero_grad() # zero-out gradients
|
|
||||||
for xb, yb in train_loader:
|
|
||||||
# Compute loss
|
|
||||||
distances, plabels = model(xb)
|
|
||||||
loss = criterion([distances, plabels], yb)
|
|
||||||
epoch_loss += loss.item()
|
|
||||||
# Backprop
|
|
||||||
loss.backward()
|
|
||||||
# Take a gradient descent step
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
lr = optimizer.param_groups[0]["lr"]
|
|
||||||
print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")
|
|
||||||
|
|
||||||
# Get the omega matrix form the model
|
|
||||||
omega = model.omega.weight.data.numpy().T
|
|
||||||
|
|
||||||
# Visualize the lambda matrix
|
|
||||||
title = "Lambda Matrix Visualization"
|
|
||||||
fig = plt.figure(title)
|
|
||||||
ax = fig.gca()
|
|
||||||
ax.set_title(title)
|
|
||||||
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# Get the prototypes form the model
|
|
||||||
protos = model.proto_layer.components.numpy()
|
|
||||||
plabels = model.proto_layer.component_labels.numpy()
|
|
||||||
|
|
||||||
# Visualize the prototypes
|
|
||||||
title = "Tecator Prototypes"
|
|
||||||
fig = plt.figure(title)
|
|
||||||
ax = fig.gca()
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.set_xlabel("Spectral frequencies")
|
|
||||||
ax.set_ylabel("Absorption")
|
|
||||||
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
|
|
||||||
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
|
|
||||||
for x, y in zip(protos, plabels):
|
|
||||||
ax.plot(x, c=colors[int(y)])
|
|
||||||
ax.legend(handles, clabels)
|
|
||||||
plt.show()
|
|
@@ -1,183 +0,0 @@
|
|||||||
"""
|
|
||||||
ProtoTorch GTLVQ example using MNIST data.
|
|
||||||
The GTLVQ is placed as an classification model on
|
|
||||||
top of a CNN, considered as featurer extractor.
|
|
||||||
Initialization of subpsace and prototypes in
|
|
||||||
Siamnese fashion
|
|
||||||
For more info about GTLVQ see:
|
|
||||||
DOI:10.1109/IJCNN.2016.7727534
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torchvision
|
|
||||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from prototorch.modules.models import GTLVQ
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
# Parameters and options
|
|
||||||
num_epochs = 50
|
|
||||||
batch_size_train = 64
|
|
||||||
batch_size_test = 1000
|
|
||||||
learning_rate = 0.1
|
|
||||||
momentum = 0.5
|
|
||||||
log_interval = 10
|
|
||||||
cuda = "cuda:0"
|
|
||||||
random_seed = 1
|
|
||||||
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# Configures reproducability
|
|
||||||
torch.manual_seed(random_seed)
|
|
||||||
np.random.seed(random_seed)
|
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
|
||||||
torchvision.datasets.MNIST(
|
|
||||||
"./files/",
|
|
||||||
train=True,
|
|
||||||
download=True,
|
|
||||||
transform=torchvision.transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
|
||||||
]),
|
|
||||||
),
|
|
||||||
batch_size=batch_size_train,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_loader = torch.utils.data.DataLoader(
|
|
||||||
torchvision.datasets.MNIST(
|
|
||||||
"./files/",
|
|
||||||
train=False,
|
|
||||||
download=True,
|
|
||||||
transform=torchvision.transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
|
||||||
]),
|
|
||||||
),
|
|
||||||
batch_size=batch_size_test,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Define the GLVQ model plus appropriate feature extractor
|
|
||||||
class CNNGTLVQ(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_classes,
|
|
||||||
subspace_data,
|
|
||||||
prototype_data,
|
|
||||||
tangent_projection_type="local",
|
|
||||||
prototypes_per_class=2,
|
|
||||||
bottleneck_dim=128,
|
|
||||||
):
|
|
||||||
super(CNNGTLVQ, self).__init__()
|
|
||||||
|
|
||||||
# Feature Extractor - Simple CNN
|
|
||||||
self.fe = nn.Sequential(
|
|
||||||
nn.Conv2d(1, 32, 3, 1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Conv2d(32, 64, 3, 1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool2d(2),
|
|
||||||
nn.Dropout(0.25),
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(9216, bottleneck_dim),
|
|
||||||
nn.Dropout(0.5),
|
|
||||||
nn.LeakyReLU(),
|
|
||||||
nn.LayerNorm(bottleneck_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
|
||||||
subspace_data = self.fe(subspace_data)
|
|
||||||
prototype_data[0] = self.fe(prototype_data[0])
|
|
||||||
|
|
||||||
# Initialization of GTLVQ
|
|
||||||
self.gtlvq = GTLVQ(
|
|
||||||
num_classes,
|
|
||||||
subspace_data,
|
|
||||||
prototype_data,
|
|
||||||
tangent_projection_type=tangent_projection_type,
|
|
||||||
feature_dim=bottleneck_dim,
|
|
||||||
prototypes_per_class=prototypes_per_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# Feature Extraction
|
|
||||||
x = self.fe(x)
|
|
||||||
|
|
||||||
# GTLVQ Forward pass
|
|
||||||
dis = self.gtlvq(x)
|
|
||||||
return dis
|
|
||||||
|
|
||||||
|
|
||||||
# Get init data
|
|
||||||
subspace_data = torch.cat(
|
|
||||||
[next(iter(train_loader))[0],
|
|
||||||
next(iter(test_loader))[0]])
|
|
||||||
prototype_data = next(iter(train_loader))
|
|
||||||
|
|
||||||
# Build the CNN GTLVQ model
|
|
||||||
model = CNNGTLVQ(
|
|
||||||
10,
|
|
||||||
subspace_data,
|
|
||||||
prototype_data,
|
|
||||||
tangent_projection_type="local",
|
|
||||||
bottleneck_dim=128,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# Optimize using SGD optimizer from `torch.optim`
|
|
||||||
optimizer = torch.optim.Adam(
|
|
||||||
[{
|
|
||||||
"params": model.fe.parameters()
|
|
||||||
}, {
|
|
||||||
"params": model.gtlvq.parameters()
|
|
||||||
}],
|
|
||||||
lr=learning_rate,
|
|
||||||
)
|
|
||||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
|
||||||
model.train()
|
|
||||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
distances = model(x_train)
|
|
||||||
plabels = model.gtlvq.cls.component_labels.to(device)
|
|
||||||
|
|
||||||
# Compute loss.
|
|
||||||
loss = criterion([distances, plabels], y_train)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
|
|
||||||
model.gtlvq.orthogonalize_subspace()
|
|
||||||
|
|
||||||
if batch_idx % log_interval == 0:
|
|
||||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
|
||||||
print(
|
|
||||||
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
|
||||||
Train Acc: {acc.item():02.02f}")
|
|
||||||
|
|
||||||
# Test
|
|
||||||
with torch.no_grad():
|
|
||||||
model.eval()
|
|
||||||
correct = 0
|
|
||||||
total = 0
|
|
||||||
for x_test, y_test in test_loader:
|
|
||||||
x_test, y_test = x_test.to(device), y_test.to(device)
|
|
||||||
test_distances = model(torch.tensor(x_test))
|
|
||||||
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
|
|
||||||
i = torch.argmin(test_distances, 1)
|
|
||||||
correct += torch.sum(y_test == test_plabels[i])
|
|
||||||
total += y_test.size(0)
|
|
||||||
print("Accuracy of the network on the test images: %d %%" %
|
|
||||||
(torch.true_divide(correct, total) * 100))
|
|
||||||
|
|
||||||
# Save the model
|
|
||||||
PATH = "./glvq_mnist_model.pth"
|
|
||||||
torch.save(model.state_dict(), PATH)
|
|
@@ -1,108 +0,0 @@
|
|||||||
"""ProtoTorch LGMLVQ example using 2D Iris data."""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
|
||||||
from prototorch.functions.competitions import stratified_min
|
|
||||||
from prototorch.functions.distances import lomega_distance
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
from sklearn.metrics import accuracy_score
|
|
||||||
|
|
||||||
# Prepare training data
|
|
||||||
x_train, y_train = load_iris(True)
|
|
||||||
x_train = x_train[:, [0, 2]]
|
|
||||||
|
|
||||||
|
|
||||||
# Define the model
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
"""Local-GMLVQ model."""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
|
||||||
prototype_distribution = [1, 2, 2]
|
|
||||||
self.proto_layer = LabeledComponents(
|
|
||||||
prototype_distribution,
|
|
||||||
prototype_initializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
omegas = torch.eye(2, 2).repeat(5, 1, 1)
|
|
||||||
self.omegas = torch.nn.Parameter(omegas)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
omegas = self.omegas
|
|
||||||
dis = lomega_distance(x, protos, omegas)
|
|
||||||
return dis, plabels
|
|
||||||
|
|
||||||
|
|
||||||
# Build the model
|
|
||||||
model = Model()
|
|
||||||
|
|
||||||
# Optimize using Adam optimizer from `torch.optim`
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
|
||||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
|
||||||
|
|
||||||
x_in = torch.Tensor(x_train)
|
|
||||||
y_in = torch.Tensor(y_train)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
title = "Prototype Visualization"
|
|
||||||
fig = plt.figure(title)
|
|
||||||
for epoch in range(100):
|
|
||||||
# Compute loss
|
|
||||||
dis, plabels = model(x_in)
|
|
||||||
loss = criterion([dis, plabels], y_in)
|
|
||||||
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
|
|
||||||
acc = accuracy_score(y_train, y_pred)
|
|
||||||
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
|
|
||||||
log_string += f"Acc: {acc * 100:05.02f}%"
|
|
||||||
print(log_string)
|
|
||||||
|
|
||||||
# Take a gradient descent step
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# Get the prototypes form the model
|
|
||||||
protos = model.proto_layer.components.numpy()
|
|
||||||
|
|
||||||
# Visualize the data and the prototypes
|
|
||||||
ax = fig.gca()
|
|
||||||
ax.cla()
|
|
||||||
ax.set_title(title)
|
|
||||||
ax.set_xlabel("Data dimension 1")
|
|
||||||
ax.set_ylabel("Data dimension 2")
|
|
||||||
cmap = "viridis"
|
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
|
||||||
ax.scatter(
|
|
||||||
protos[:, 0],
|
|
||||||
protos[:, 1],
|
|
||||||
c=plabels,
|
|
||||||
cmap=cmap,
|
|
||||||
edgecolor="k",
|
|
||||||
marker="D",
|
|
||||||
s=50,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Paint decision regions
|
|
||||||
x = np.vstack((x_train, protos))
|
|
||||||
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
|
||||||
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
|
||||||
np.arange(y_min, y_max, 1 / 50))
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
||||||
|
|
||||||
d, plabels = model(torch.Tensor(mesh_input))
|
|
||||||
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
|
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
|
||||||
|
|
||||||
# Plot voronoi regions
|
|
||||||
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
|
|
||||||
|
|
||||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
|
||||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
|
||||||
|
|
||||||
plt.pause(0.1)
|
|
@@ -1,6 +1,7 @@
|
|||||||
"""ProtoTorch package"""
|
"""ProtoTorch package"""
|
||||||
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ from .core import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
__version__ = "0.5.0"
|
__version__ = "0.7.1"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"competitions",
|
"competitions",
|
||||||
@@ -39,7 +40,7 @@ __all_core__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__)
|
__path__: List[str] = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|
||||||
|
|
||||||
def discover_plugins():
|
def discover_plugins():
|
||||||
|
@@ -3,8 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def wtac(distances: torch.Tensor,
|
def wtac(distances: torch.Tensor, labels: torch.LongTensor):
|
||||||
labels: torch.LongTensor) -> (torch.LongTensor):
|
|
||||||
"""Winner-Takes-All-Competition.
|
"""Winner-Takes-All-Competition.
|
||||||
|
|
||||||
Returns the labels corresponding to the winners.
|
Returns the labels corresponding to the winners.
|
||||||
@@ -15,9 +14,7 @@ def wtac(distances: torch.Tensor,
|
|||||||
return winning_labels
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
def knnc(distances: torch.Tensor,
|
def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1):
|
||||||
labels: torch.LongTensor,
|
|
||||||
k: int = 1) -> (torch.LongTensor):
|
|
||||||
"""K-Nearest-Neighbors-Competition.
|
"""K-Nearest-Neighbors-Competition.
|
||||||
|
|
||||||
Returns the labels corresponding to the winners.
|
Returns the labels corresponding to the winners.
|
||||||
@@ -51,7 +48,7 @@ class WTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, distances, labels):
|
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||||
return wtac(distances, labels)
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
@@ -61,7 +58,7 @@ class LTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, probs, labels):
|
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||||
return wtac(-1.0 * probs, labels)
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
@@ -88,5 +85,5 @@ class CBCC(torch.nn.Module):
|
|||||||
Thin wrapper over the `cbcc` function.
|
Thin wrapper over the `cbcc` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, detections, reasonings):
|
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||||
return cbcc(detections, reasonings)
|
return cbcc(detections, reasonings)
|
||||||
|
@@ -86,8 +86,8 @@ class AbstractComponents(torch.nn.Module):
|
|||||||
class Components(AbstractComponents):
|
class Components(AbstractComponents):
|
||||||
"""A set of adaptable Tensors."""
|
"""A set of adaptable Tensors."""
|
||||||
def __init__(self, num_components: int,
|
def __init__(self, num_components: int,
|
||||||
initializer: AbstractComponentsInitializer, **kwargs):
|
initializer: AbstractComponentsInitializer):
|
||||||
super().__init__(**kwargs)
|
super().__init__()
|
||||||
self.add_components(num_components, initializer)
|
self.add_components(num_components, initializer)
|
||||||
|
|
||||||
def add_components(self, num_components: int,
|
def add_components(self, num_components: int,
|
||||||
@@ -154,9 +154,8 @@ class Labels(AbstractLabels):
|
|||||||
"""A set of standalone labels."""
|
"""A set of standalone labels."""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
initializer: AbstractLabelsInitializer = LabelsInitializer(),
|
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||||
**kwargs):
|
super().__init__()
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.add_labels(distribution, initializer)
|
self.add_labels(distribution, initializer)
|
||||||
|
|
||||||
def add_labels(
|
def add_labels(
|
||||||
@@ -184,13 +183,11 @@ class Labels(AbstractLabels):
|
|||||||
class LabeledComponents(AbstractComponents):
|
class LabeledComponents(AbstractComponents):
|
||||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
components_initializer: AbstractComponentsInitializer,
|
components_initializer: AbstractComponentsInitializer,
|
||||||
labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
|
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||||
),
|
super().__init__()
|
||||||
**kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.add_components(distribution, components_initializer,
|
self.add_components(distribution, components_initializer,
|
||||||
labels_initializer)
|
labels_initializer)
|
||||||
|
|
||||||
@@ -252,12 +249,14 @@ class Reasonings(torch.nn.Module):
|
|||||||
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(
|
||||||
distribution: Union[dict, list, tuple],
|
self,
|
||||||
initializer:
|
distribution: Union[dict, list, tuple],
|
||||||
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
initializer:
|
||||||
**kwargs):
|
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||||
super().__init__(**kwargs)
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.add_reasonings(distribution, initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
@@ -295,7 +294,7 @@ class Reasonings(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ReasoningComponents(AbstractComponents):
|
class ReasoningComponents(AbstractComponents):
|
||||||
"""A set of components and a corresponding adapatable reasoning matrices.
|
r"""A set of components and a corresponding adapatable reasoning matrices.
|
||||||
|
|
||||||
Every component has its own reasoning matrix.
|
Every component has its own reasoning matrix.
|
||||||
|
|
||||||
@@ -310,13 +309,12 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
components_initializer: AbstractComponentsInitializer,
|
components_initializer: AbstractComponentsInitializer,
|
||||||
reasonings_initializer:
|
reasonings_initializer:
|
||||||
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
|
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()):
|
||||||
**kwargs):
|
super().__init__()
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.add_components(distribution, components_initializer,
|
self.add_components(distribution, components_initializer,
|
||||||
reasonings_initializer)
|
reasonings_initializer)
|
||||||
|
|
||||||
|
@@ -41,9 +41,6 @@ def euclidean_distance_v2(x, y):
|
|||||||
# batch diagonal. See:
|
# batch diagonal. See:
|
||||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
|
||||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
|
||||||
# print(f"{distances.shape=}") # (nx, ny)
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
@@ -3,7 +3,11 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Union
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -110,9 +114,9 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.TensorType,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
transform: callable = torch.nn.Identity()):
|
transform: Callable = torch.nn.Identity()):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
@@ -151,14 +155,14 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by computing the mean of the provided data."""
|
"""Generate components by computing the mean of the provided data."""
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
mean = torch.mean(self.data, dim=0)
|
mean = self.data.mean(dim=0)
|
||||||
repeat_dim = [num_components] + [1] * len(mean.shape)
|
repeat_dim = [num_components] + [1] * len(mean.shape)
|
||||||
samples = mean.repeat(repeat_dim)
|
samples = mean.repeat(repeat_dim)
|
||||||
components = self.generate_end_hook(samples)
|
components = self.generate_end_hook(samples)
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
|
class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||||
"""Abstract class for all class-aware components initializers.
|
"""Abstract class for all class-aware components initializers.
|
||||||
|
|
||||||
Components generated by class-aware components initializers inherit the shape
|
Components generated by class-aware components initializers inherit the shape
|
||||||
@@ -171,13 +175,18 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
data,
|
data,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
transform: callable = torch.nn.Identity()):
|
transform: Callable = torch.nn.Identity()):
|
||||||
self.data, self.targets = parse_data_arg(data)
|
self.data, self.targets = parse_data_arg(data)
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.clabels = torch.unique(self.targets).int().tolist()
|
self.clabels = torch.unique(self.targets).int().tolist()
|
||||||
self.num_classes = len(self.clabels)
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
|
def generate_end_hook(self, samples):
|
||||||
|
drift = torch.rand_like(samples) * self.noise
|
||||||
|
components = self.transform(samples + drift)
|
||||||
|
return components
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
...
|
...
|
||||||
@@ -200,7 +209,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
"""Abstract class for all stratified components initializers."""
|
"""Abstract class for all stratified components initializers."""
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def subinit_type(self) -> AbstractDataAwareCompInitializer:
|
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
@@ -276,10 +285,10 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
|||||||
"""Generate labels from `distribution`."""
|
"""Generate labels from `distribution`."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
labels = []
|
labels_list = []
|
||||||
for k, v in distribution.items():
|
for k, v in distribution.items():
|
||||||
labels.extend([k] * v)
|
labels_list.extend([k] * v)
|
||||||
labels = torch.LongTensor(labels)
|
labels = torch.LongTensor(labels_list)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@@ -294,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
# Reasonings
|
# Reasonings
|
||||||
|
def compute_distribution_shape(distribution):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
num_components = sum(distribution.values())
|
||||||
|
num_classes = len(distribution.keys())
|
||||||
|
return (num_components, num_classes, 2)
|
||||||
|
|
||||||
|
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
def __init__(self, components_first: bool = True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
def compute_shape(self, distribution):
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
num_components = sum(distribution.values())
|
|
||||||
num_classes = len(distribution.keys())
|
|
||||||
return (num_components, num_classes, 2)
|
|
||||||
|
|
||||||
def generate_end_hook(self, reasonings):
|
def generate_end_hook(self, reasonings):
|
||||||
if not self.components_first:
|
if not self.components_first:
|
||||||
reasonings = reasonings.permute(2, 1, 0)
|
reasonings = reasonings.permute(2, 1, 0)
|
||||||
@@ -340,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.zeros(*shape)
|
reasonings = torch.zeros(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -349,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with ones."""
|
"""Reasonings are all initialized with ones."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape)
|
reasonings = torch.ones(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -363,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
self.maximum = maximum
|
self.maximum = maximum
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -372,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Each component reasons positively for exactly one class."""
|
"""Each component reasons positively for exactly one class."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
num_components, num_classes, _ = compute_distribution_shape(
|
||||||
|
distribution)
|
||||||
A = OneHotLabelsInitializer().generate(distribution)
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
B = torch.zeros(num_components, num_classes)
|
B = torch.zeros(num_components, num_classes)
|
||||||
reasonings = torch.stack([A, B], dim=-1)
|
reasonings = torch.stack([A, B], dim=-1)
|
||||||
@@ -425,6 +436,33 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||||
|
"""Abstract class for all data-aware linear transform initializers."""
|
||||||
|
def __init__(self,
|
||||||
|
data: torch.Tensor,
|
||||||
|
noise: float = 0.0,
|
||||||
|
transform: Callable = torch.nn.Identity(),
|
||||||
|
out_dim_first: bool = False):
|
||||||
|
super().__init__(out_dim_first)
|
||||||
|
self.data = data
|
||||||
|
self.noise = noise
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def generate_end_hook(self, weights: torch.Tensor):
|
||||||
|
drift = torch.rand_like(weights) * self.noise
|
||||||
|
weights = self.transform(weights + drift)
|
||||||
|
if self.out_dim_first:
|
||||||
|
weights = weights.permute(1, 0)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
|
"""Initialize a matrix with Eigenvectors from the data."""
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||||
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
# Aliases - Components
|
# Aliases - Components
|
||||||
CACI = ClassAwareCompInitializer
|
CACI = ClassAwareCompInitializer
|
||||||
DACI = DataAwareCompInitializer
|
DACI = DataAwareCompInitializer
|
||||||
@@ -456,3 +494,4 @@ ZRI = ZerosReasoningsInitializer
|
|||||||
Eye = EyeTransformInitializer
|
Eye = EyeTransformInitializer
|
||||||
OLTI = OnesLinearTransformInitializer
|
OLTI = OnesLinearTransformInitializer
|
||||||
ZLTI = ZerosLinearTransformInitializer
|
ZLTI = ZerosLinearTransformInitializer
|
||||||
|
PCALTI = PCALinearTransformInitializer
|
||||||
|
@@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
|||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.transfer_fn = get_activation(transfer_fn)
|
||||||
self.beta = torch.tensor(beta)
|
self.beta = torch.tensor(beta)
|
||||||
|
|
||||||
def forward(self, outputs, targets):
|
def forward(self, outputs, targets, plabels):
|
||||||
distances, plabels = outputs
|
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
return batch_loss.sum()
|
||||||
return torch.sum(batch_loss, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||||
|
@@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
|
|||||||
|
|
||||||
class StratifiedSumPooling(torch.nn.Module):
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_sum_pooling(values, labels)
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedProdPooling(torch.nn.Module):
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_prod_pooling(values, labels)
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMinPooling(torch.nn.Module):
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_min_pooling(values, labels)
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMaxPooling(torch.nn.Module):
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_max_pooling(values, labels)
|
return stratified_max_pooling(values, labels)
|
||||||
|
@@ -11,13 +11,12 @@ from .initializers import (
|
|||||||
|
|
||||||
class LinearTransform(torch.nn.Module):
|
class LinearTransform(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int,
|
out_dim: int,
|
||||||
initializer:
|
initializer:
|
||||||
AbstractLinearTransformInitializer = EyeTransformInitializer(),
|
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||||
**kwargs):
|
super().__init__()
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.set_weights(in_dim, out_dim, initializer)
|
self.set_weights(in_dim, out_dim, initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -37,7 +36,7 @@ class LinearTransform(torch.nn.Module):
|
|||||||
self._register_weights(weights)
|
self._register_weights(weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x @ self.weights.T
|
return x @ self.weights
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
"""ProtoTorch datasets"""
|
"""ProtoTorch datasets"""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import CSVDataset, NumpyDataset
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
Blobs,
|
Blobs,
|
||||||
Circles,
|
Circles,
|
||||||
@@ -10,3 +10,4 @@ from .sklearn import (
|
|||||||
)
|
)
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
from .xor import XOR
|
||||||
|
@@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
|||||||
self.targets = torch.LongTensor(targets)
|
self.targets = torch.LongTensor(targets)
|
||||||
tensors = [self.data, self.targets]
|
tensors = [self.data, self.targets]
|
||||||
super().__init__(*tensors)
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVDataset(NumpyDataset):
|
||||||
|
"""Create a Dataset from a CSV file."""
|
||||||
|
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
|
||||||
|
raw = np.genfromtxt(
|
||||||
|
filepath,
|
||||||
|
delimiter=delimiter,
|
||||||
|
skip_header=skip_header,
|
||||||
|
)
|
||||||
|
data = np.delete(raw, 1, target_col)
|
||||||
|
targets = raw[:, target_col]
|
||||||
|
super().__init__(data, targets)
|
||||||
|
@@ -8,11 +8,11 @@ URL:
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
|
||||||
|
|
||||||
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
||||||
make_classification, make_moons)
|
make_classification, make_moons)
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
|
||||||
|
|
||||||
class Iris(NumpyDataset):
|
class Iris(NumpyDataset):
|
||||||
"""Iris Dataset by Ronald Fisher introduced in 1936.
|
"""Iris Dataset by Ronald Fisher introduced in 1936.
|
||||||
|
@@ -40,9 +40,10 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.datasets.abstract import ProtoDataset
|
|
||||||
from torchvision.datasets.utils import download_file_from_google_drive
|
from torchvision.datasets.utils import download_file_from_google_drive
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import ProtoDataset
|
||||||
|
|
||||||
|
|
||||||
class Tecator(ProtoDataset):
|
class Tecator(ProtoDataset):
|
||||||
"""
|
"""
|
||||||
|
18
prototorch/datasets/xor.py
Normal file
18
prototorch/datasets/xor.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_xor(num_samples=500):
|
||||||
|
x = torch.rand(num_samples, 2)
|
||||||
|
y = torch.zeros(num_samples)
|
||||||
|
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||||
|
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class XOR(torch.utils.data.TensorDataset):
|
||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
def __init__(self, num_samples: int = 500):
|
||||||
|
x, y = make_xor(num_samples)
|
||||||
|
super().__init__(x, y)
|
@@ -1,7 +1,12 @@
|
|||||||
"""ProtoFlow utilities"""
|
"""ProtoFlow utilities"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Union
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -23,15 +28,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
return mesh, xx, yy
|
return mesh, xx, yy
|
||||||
|
|
||||||
|
|
||||||
def distribution_from_list(list_dist: list[int], clabels: list[int] = []):
|
def distribution_from_list(list_dist: List[int],
|
||||||
|
clabels: Iterable[int] = None):
|
||||||
clabels = clabels or list(range(len(list_dist)))
|
clabels = clabels or list(range(len(list_dist)))
|
||||||
distribution = dict(zip(clabels, list_dist))
|
distribution = dict(zip(clabels, list_dist))
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
def parse_distribution(user_distribution,
|
||||||
list[int], tuple[int]],
|
clabels: Iterable[int] = None) -> Dict[int, int]:
|
||||||
clabels: list[int] = []) -> dict[int, int]:
|
|
||||||
"""Parse user-provided distribution.
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
Return a dictionary with integer keys that represent the class labels and
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
@@ -75,9 +80,13 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
|||||||
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||||
"""Return data and target as torch tensors."""
|
"""Return data and target as torch tensors."""
|
||||||
if isinstance(data_arg, Dataset):
|
if isinstance(data_arg, Dataset):
|
||||||
ds_size = len(data_arg)
|
if hasattr(data_arg, "__len__"):
|
||||||
loader = DataLoader(data_arg, batch_size=ds_size)
|
ds_size = len(data_arg) # type: ignore
|
||||||
data, targets = next(iter(loader))
|
loader = DataLoader(data_arg, batch_size=ds_size)
|
||||||
|
data, targets = next(iter(loader))
|
||||||
|
else:
|
||||||
|
emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
|
||||||
elif isinstance(data_arg, DataLoader):
|
elif isinstance(data_arg, DataLoader):
|
||||||
data = torch.tensor([])
|
data = torch.tensor([])
|
||||||
|
45
setup.py
45
setup.py
@@ -1,10 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
_____ _ _______ _
|
|
||||||
| __ \ | | |__ __| | |
|
######
|
||||||
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
|
# # ##### #### ##### #### ##### #### ##### #### # #
|
||||||
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
|
# # # # # # # # # # # # # # # # # #
|
||||||
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|
###### # # # # # # # # # # # # # ######
|
||||||
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|
|
# ##### # # # # # # # # ##### # # #
|
||||||
|
# # # # # # # # # # # # # # # # #
|
||||||
|
# # # #### # #### # #### # # #### # #
|
||||||
|
|
||||||
ProtoTorch Core Package
|
ProtoTorch Core Package
|
||||||
"""
|
"""
|
||||||
@@ -18,7 +20,7 @@ with open("README.md", "r") as fh:
|
|||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"torch>=1.3.1",
|
"torch>=1.3.1",
|
||||||
"torchvision>=0.5.0",
|
"torchvision>=0.7.1",
|
||||||
"numpy>=1.9.1",
|
"numpy>=1.9.1",
|
||||||
"sklearn",
|
"sklearn",
|
||||||
]
|
]
|
||||||
@@ -26,7 +28,10 @@ DATASETS = [
|
|||||||
"requests",
|
"requests",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
]
|
]
|
||||||
DEV = ["bumpversion"]
|
DEV = [
|
||||||
|
"bumpversion",
|
||||||
|
"pre-commit",
|
||||||
|
]
|
||||||
DOCS = [
|
DOCS = [
|
||||||
"recommonmark",
|
"recommonmark",
|
||||||
"sphinx",
|
"sphinx",
|
||||||
@@ -38,12 +43,15 @@ EXAMPLES = [
|
|||||||
"matplotlib",
|
"matplotlib",
|
||||||
"torchinfo",
|
"torchinfo",
|
||||||
]
|
]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = [
|
||||||
|
"pytest-cov",
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.5.0",
|
version="0.7.1",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
@@ -54,30 +62,33 @@ setup(
|
|||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
download_url=DOWNLOAD_URL,
|
download_url=DOWNLOAD_URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
|
python_requires=">=3.6",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"docs": DOCS,
|
|
||||||
"datasets": DATASETS,
|
"datasets": DATASETS,
|
||||||
|
"dev": DEV,
|
||||||
|
"docs": DOCS,
|
||||||
"examples": EXAMPLES,
|
"examples": EXAMPLES,
|
||||||
"tests": TESTS,
|
"tests": TESTS,
|
||||||
"all": ALL,
|
"all": ALL,
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 2 - Pre-Alpha",
|
|
||||||
"Environment :: Console",
|
"Environment :: Console",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: Software Development :: Libraries",
|
||||||
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Natural Language :: English",
|
"Operating System :: OS Independent",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Software Development :: Libraries",
|
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
||||||
],
|
],
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
|
|||||||
|
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
def test_linear_transform():
|
def test_linear_transform_default_eye_init():
|
||||||
l = pt.transforms.LinearTransform(2, 4)
|
l = pt.transforms.LinearTransform(2, 4)
|
||||||
actual = l.weights
|
actual = l.weights
|
||||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_forward():
|
||||||
|
l = pt.transforms.LinearTransform(4, 2)
|
||||||
|
actual_weights = l.weights
|
||||||
|
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||||
|
assert torch.allclose(actual_weights, desired_weights)
|
||||||
|
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[5.5, 6.6, 7.7, 8.8]]))
|
||||||
|
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||||
|
assert torch.allclose(actual_outputs, desired_outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transform_zeros_init():
|
def test_linear_transform_zeros_init():
|
||||||
l = pt.transforms.LinearTransform(
|
l = pt.transforms.LinearTransform(
|
||||||
in_dim=2,
|
in_dim=2,
|
||||||
|
@@ -49,6 +49,23 @@ class TestNumpyDataset(unittest.TestCase):
|
|||||||
self.assertEqual(len(ds), 3)
|
self.assertEqual(len(ds), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSVDataset(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
data = np.random.rand(100, 4)
|
||||||
|
targets = np.random.randint(2, size=(100, 1))
|
||||||
|
arr = np.hstack([data, targets])
|
||||||
|
if not os.path.exists("./artifacts"):
|
||||||
|
os.mkdir("./artifacts")
|
||||||
|
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
|
||||||
|
self.assertEqual(len(ds), 100)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
os.remove("./artifacts/test.csv")
|
||||||
|
|
||||||
|
|
||||||
class TestSpiral(unittest.TestCase):
|
class TestSpiral(unittest.TestCase):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
ds = pt.datasets.Spiral(num_samples=10)
|
ds = pt.datasets.Spiral(num_samples=10)
|
||||||
@@ -94,67 +111,67 @@ class TestMoons(unittest.TestCase):
|
|||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestTecator(unittest.TestCase):
|
# class TestTecator(unittest.TestCase):
|
||||||
def setUp(self):
|
# def setUp(self):
|
||||||
self.artifacts_dir = "./artifacts/Tecator"
|
# self.artifacts_dir = "./artifacts/Tecator"
|
||||||
self._remove_artifacts()
|
# self._remove_artifacts()
|
||||||
|
|
||||||
def _remove_artifacts(self):
|
# def _remove_artifacts(self):
|
||||||
if os.path.exists(self.artifacts_dir):
|
# if os.path.exists(self.artifacts_dir):
|
||||||
shutil.rmtree(self.artifacts_dir)
|
# shutil.rmtree(self.artifacts_dir)
|
||||||
|
|
||||||
def test_download_false(self):
|
# def test_download_false(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
self._remove_artifacts()
|
# self._remove_artifacts()
|
||||||
with self.assertRaises(RuntimeError):
|
# with self.assertRaises(RuntimeError):
|
||||||
_ = pt.datasets.Tecator(rootdir, download=False)
|
# _ = pt.datasets.Tecator(rootdir, download=False)
|
||||||
|
|
||||||
def test_download_caching(self):
|
# def test_download_caching(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
_ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
|
# _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
|
||||||
_ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
|
# _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
|
||||||
|
|
||||||
def test_repr(self):
|
# def test_repr(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
|
# train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
|
||||||
self.assertTrue("Split: Train" in train.__repr__())
|
# self.assertTrue("Split: Train" in train.__repr__())
|
||||||
|
|
||||||
def test_download_train(self):
|
# def test_download_train(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = pt.datasets.Tecator(root=rootdir,
|
# train = pt.datasets.Tecator(root=rootdir,
|
||||||
train=True,
|
# train=True,
|
||||||
download=True,
|
# download=True,
|
||||||
verbose=False)
|
# verbose=False)
|
||||||
train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
|
# train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
|
||||||
x_train, y_train = train.data, train.targets
|
# x_train, y_train = train.data, train.targets
|
||||||
self.assertEqual(x_train.shape[0], 144)
|
# self.assertEqual(x_train.shape[0], 144)
|
||||||
self.assertEqual(y_train.shape[0], 144)
|
# self.assertEqual(y_train.shape[0], 144)
|
||||||
self.assertEqual(x_train.shape[1], 100)
|
# self.assertEqual(x_train.shape[1], 100)
|
||||||
|
|
||||||
def test_download_test(self):
|
# def test_download_test(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x_test, y_test = test.data, test.targets
|
# x_test, y_test = test.data, test.targets
|
||||||
self.assertEqual(x_test.shape[0], 71)
|
# self.assertEqual(x_test.shape[0], 71)
|
||||||
self.assertEqual(y_test.shape[0], 71)
|
# self.assertEqual(y_test.shape[0], 71)
|
||||||
self.assertEqual(x_test.shape[1], 100)
|
# self.assertEqual(x_test.shape[1], 100)
|
||||||
|
|
||||||
def test_class_to_idx(self):
|
# def test_class_to_idx(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = test.class_to_idx
|
# _ = test.class_to_idx
|
||||||
|
|
||||||
def test_getitem(self):
|
# def test_getitem(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x, y = test[0]
|
# x, y = test[0]
|
||||||
self.assertEqual(x.shape[0], 100)
|
# self.assertEqual(x.shape[0], 100)
|
||||||
self.assertIsInstance(y, int)
|
# self.assertIsInstance(y, int)
|
||||||
|
|
||||||
def test_loadable_with_dataloader(self):
|
# def test_loadable_with_dataloader(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
# rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
# _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
def tearDown(self):
|
# def tearDown(self):
|
||||||
self._remove_artifacts()
|
# self._remove_artifacts()
|
||||||
|
Reference in New Issue
Block a user