50 Commits

Author SHA1 Message Date
Alexander Engelsberger
0788718c31 ci: cache test IV 2021-11-05 15:14:59 +01:00
Alexander Engelsberger
4f5c4ebe8f ci: cache test III 2021-11-05 15:09:21 +01:00
Alexander Engelsberger
ae2a8e54ef ci: cache test II 2021-11-05 14:59:01 +01:00
Alexander Engelsberger
d9be100c1f ci: use pip cache in jenkins 2021-11-05 14:55:12 +01:00
Alexander Engelsberger
9d1dc7320f ci: fix jenkinsfile 2021-11-05 14:32:37 +01:00
Alexander Engelsberger
d11ab71b7e ci: unit tests in jenkins 2021-11-05 14:30:08 +01:00
Alexander Engelsberger
59037e1a50 ci: upgrade pip before install 2021-11-04 10:53:51 +01:00
Alexander Engelsberger
a19b99be82 ci: container debugging III 2021-11-04 10:50:44 +01:00
Alexander Engelsberger
f7e7558338 ci: container debugging II 2021-11-04 10:42:53 +01:00
Alexander Engelsberger
d57648f9d6 ci: container debugging 2021-11-04 10:41:28 +01:00
Alexander Engelsberger
d24f580bf0 ci: install dependencies with user flag 2021-11-04 09:55:58 +01:00
Jensun Ravichandran
916973c3e8 ci: migrate to jenkins 2021-11-03 16:26:32 +01:00
Alexander Engelsberger
b49b7a2d41 build: bump version 0.7.0 → 0.7.1 2021-08-30 17:55:48 +02:00
Alexander Engelsberger
b6e8242383 ci: add build phase for tags 2021-08-30 17:55:32 +02:00
Alexander Engelsberger
cd616d11b9 build: bump version 0.6.0 → 0.7.0 2021-08-30 17:42:27 +02:00
Alexander Engelsberger
afcfcb8973 fix: setup.py tags 2021-08-30 17:42:22 +02:00
Alexander Engelsberger
bf03a45475 feat(compatibility): Python3.6 compatibility 2021-08-30 17:39:10 +02:00
Alexander Engelsberger
083b5c1597 feat(compatibility): Python3.7 compatibility 2021-08-30 17:39:10 +02:00
Alexander Engelsberger
7f0a8e9bce feat(compatibility): Python3.8 compatibility 2021-08-30 17:39:10 +02:00
Jensun Ravichandran
bf09ff8f7f feat: add XOR dataset 2021-07-15 18:14:38 +02:00
Jensun Ravichandran
c1d7cfee8f fix(test): fix broken CSVDataset test 2021-07-06 17:07:26 +02:00
Jensun Ravichandran
99be965581 refactor: refactor GLVQLoss 2021-07-06 17:01:28 +02:00
Jensun Ravichandran
fdb9a7c66d feat: add CSVDataset 2021-07-04 16:30:01 +02:00
Jensun Ravichandran
eb79b703d8 chore(github): update bug report issue template 2021-06-22 15:06:18 +02:00
Jensun Ravichandran
bc9a826b7d fix: matmul bug in 2021-06-21 22:48:22 +02:00
Alexander Engelsberger
cfe09ec06b fix: reasonings init parameters are used now 2021-06-21 14:53:22 +02:00
Alexander Engelsberger
3d76dffe3c chore: Allow no-self-use for some class members
Classes are used as common interface and connection to pytorch.
2021-06-21 14:29:25 +02:00
Jensun Ravichandran
597c9fc1ee build: bump version 0.5.1 → 0.6.0 2021-06-20 19:12:01 +02:00
Jensun Ravichandran
a8c74a1a6f chore(bumpversion): modify bump message 2021-06-20 19:09:35 +02:00
Jensun Ravichandran
f78ff1a464 fix(initializers): bug fixes in LT initializers 2021-06-20 18:56:06 +02:00
Jensun Ravichandran
5a3dbfac2e chore(pre-commit): prettify .pre-commit-config.yaml 2021-06-20 18:54:37 +02:00
Jensun Ravichandran
478a3c2cfe fix: python is python3.9 2021-06-20 17:49:53 +02:00
Jensun Ravichandran
4520fdde8e chore(travis): point build badge to travis-ci.com 2021-06-18 19:27:28 +02:00
Jensun Ravichandran
b90044b86c fix: python is python3.9 2021-06-18 19:20:54 +02:00
Jensun Ravichandran
a1310df4ee test(datasets): turn off tecator tests temporarily 2021-06-18 19:10:29 +02:00
Jensun Ravichandran
5dc66494ea refactor(api)!: merge the new api changes into dev
BREAKING CHANGE: remove the following
`prototorch/functions/*`
`prototorch/components/*`
`prototorch/modules/*`
BREAKING CHANGE: move `initializers` into the `prototorch.initializers`
namespace from the `prototorch.components` namespace
BREAKING CHANGE: `functions` and `modules` are moved into `core` and `nn`
2021-06-18 18:54:55 +02:00
Jensun Ravichandran
74d420a77d refactor(api)!: merge the new api changes into dev
BREAKING CHANGE: remove the following
`prototorch/functions/*`
`prototorch/components/*`
`prototorch/modules/*`
BREAKING CHANGE: move `initializers` into the `prototorch.initializers`
namespace from the `prototorch.components` namespace
BREAKING CHANGE: `functions` and `modules` are moved into `core` and `nn`
2021-06-18 18:20:30 +02:00
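
The two merge commits above record the same breaking API change. A minimal sketch of what the migration means for user code; the new paths are taken from the commit message and the file diffs further down, and the exact re-export surface is an assumption:

```python
# Removed pre-0.6.0 namespaces:
#   from prototorch.functions.competitions import wtac
#   from prototorch.modules.losses import GLVQLoss
#   from prototorch.components import LabeledComponents

# Post-refactor, `functions` and `modules` live in `core` and `nn`,
# and initializers sit in the `prototorch.initializers` namespace:
from prototorch.core.competitions import wtac
from prototorch.core.components import LabeledComponents
from prototorch.core.losses import GLVQLoss
from prototorch.initializers import PPRI, SSCI  # moved out of `prototorch.components`
```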
Jensun Ravichandran
6ffd14e85c Bump version: 0.5.0 → 0.5.1 2021-06-18 15:49:20 +02:00
Jensun Ravichandran
40c1021c20 Remove examples 2021-06-18 13:41:03 +02:00
Jensun Ravichandran
acf3272fd7 Remove .swp files 2021-06-18 13:39:43 +02:00
danielstaps
c73f8e7a28 Added PCA initializer and component for OmegaMatrix or LinearMappings (#6)
* Added PCA initializer and component for OmegaMatrix or LinearMappings

* [QA] Add default configuration for pre commit hooks

* [QA] Add more pre commit checks

* [QA] Add more pre commit checks

* test(githooks): Add gitlint to check commit messages on commit

* docs(githooks): Add usage guide for pre-commit  to readme

* fix(githooks): mypy only checks source now

reverts changes on docs conf.py

* docs(githooks): Fix typo

Co-authored-by: staps@hs-mittweida.de <staps@hs-mittweida.de>
Co-authored-by: Alexander Engelsberger <alexanderengelsberger@gmail.com>
2021-06-18 13:28:25 +02:00
Alexander Engelsberger
bf23d5f7f8 docs(githooks): Fix typo 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
bcde3f6ac8 fix(githooks): mypy only checks source now
reverts changes on docs conf.py
2021-06-16 15:23:23 +02:00
Alexander Engelsberger
d5229b1750 docs(githooks): Add usage guide for pre-commit to readme 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
fc4b143fbb test(githooks): Add gitlint to check commit messages on commit 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
11cfa79746 [QA] Add more pre commit checks 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
d0ae94f2af [QA] Add more pre commit checks 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
2c908a8361 [QA] Add default configuration for pre commit hooks 2021-06-16 15:23:23 +02:00
Alexander Engelsberger
e4257ec1f1 Merge branch 'dev' of github.com:si-cim/prototorch into dev 2021-06-11 16:10:04 +02:00
Alexander Engelsberger
aaad2b8626 [BUGFIX] Fix labeled components if initialized 2021-06-11 16:09:51 +02:00
35 changed files with 564 additions and 728 deletions

.bumpversion.cfg
View File

@@ -1,10 +1,10 @@
 [bumpversion]
-current_version = 0.5.0
+current_version = 0.7.1
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
-serialize =
-	{major}.{minor}.{patch}
+serialize = {major}.{minor}.{patch}
+message = build: bump version {current_version} → {new_version}
 [bumpversion:file:setup.py]

7
.ci/python310.Dockerfile Normal file
View File

@@ -0,0 +1,7 @@
FROM python:3.9

RUN adduser --uid 1000 jenkins

USER jenkins

RUN mkdir -p /home/jenkins/.cache/pip

7
.ci/python36.Dockerfile Normal file
View File

@@ -0,0 +1,7 @@
FROM python:3.6

RUN adduser --uid 1000 jenkins

USER jenkins

RUN mkdir -p /home/jenkins/.cache/pip

.github/ISSUE_TEMPLATE/bug_report.md
View File

@@ -10,21 +10,28 @@ assignees: ''
 **Describe the bug**
 A clear and concise description of what the bug is.

-**To Reproduce**
-Steps to reproduce the behavior:
-1. Install Prototorch by running '...'
-2. Run script '...'
+**Steps to reproduce the behavior**
+1. ...
+2. Run script '...' or this snippet:
+```python
+import prototorch as pt
+...
+```
 3. See errors

 **Expected behavior**
 A clear and concise description of what you expected to happen.

+**Observed behavior**
+A clear and concise description of what actually happened.
+
 **Screenshots**
 If applicable, add screenshots to help explain your problem.

-**Desktop (please complete the following information):**
+**System and version information**
 - OS: [e.g. Ubuntu 20.10]
-- Prototorch Version: [e.g. v0.4.0]
+- ProtoTorch Version: [e.g. 0.4.0]
 - Python Version: [e.g. 3.9.5]

 **Additional context**

View File

@@ -16,10 +16,10 @@ jobs:
     steps:
     - uses: actions/checkout@v2
-    - name: Set up Python 3.8
+    - name: Set up Python 3.9
       uses: actions/setup-python@v1
       with:
-        python-version: 3.8
+        python-version: 3.9
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip

2
.gitignore vendored
View File

@@ -155,4 +155,4 @@ dmypy.json
 reports
 artifacts
 examples/_*.py
-examples/_*.ipynb
\ No newline at end of file
+examples/_*.ipynb

53
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,53 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-case-conflict

  - repo: https://github.com/myint/autoflake
    rev: v1.4
    hooks:
      - id: autoflake

  - repo: http://github.com/PyCQA/isort
    rev: 5.8.0
    hooks:
      - id: isort

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.902
    hooks:
      - id: mypy
        files: prototorch
        additional_dependencies: [types-pkg_resources]

  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.31.0
    hooks:
      - id: yapf

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.9.0
    hooks:
      - id: python-use-type-annotations
      - id: python-no-log-warn
      - id: python-check-blanket-noqa

  - repo: https://github.com/asottile/pyupgrade
    rev: v2.19.4
    hooks:
      - id: pyupgrade

  - repo: https://github.com/si-cim/gitlint
    rev: v0.15.2-unofficial
    hooks:
      - id: gitlint
        args: [--contrib=CT1, --ignore=B6, --msg-filename]

.readthedocs.yml
View File

@@ -19,7 +19,7 @@ formats: all
 # Optionally set the version of Python and requirements required to build your docs
 python:
-  version: 3.8
+  version: 3.9
   install:
     - method: pip
       path: .

.travis.yml
View File

@@ -1,36 +0,0 @@
dist: bionic
sudo: false
language: python
python: 3.9
cache:
  directories:
    - "$HOME/.cache/pip"
    - "./tests/artifacts"
    - "$HOME/datasets"
install:
  - pip install .[all] --progress-bar off
# Generate code coverage report
script:
  - coverage run -m pytest
# Push the results to codecov
after_success:
  - bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
deploy:
  provider: pypi
  username: __token__
  password:
    secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
  on:
    tags: true
    skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

41
Jenkinsfile vendored Normal file
View File

@@ -0,0 +1,41 @@
pipeline {
    agent none
    stages {
        stage('Unit Tests') {
            parallel {
                stage('3.6'){
                    agent{
                        dockerfile {
                            filename 'python36.Dockerfile'
                            dir '.ci'
                            args '-v pip-cache:/home/jenkins/.cache/pip'
                        }
                    }
                    steps {
                        sh 'pip install pip --upgrade --progress-bar off'
                        sh 'pip install .[all] --progress-bar off'
                        sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
                        cobertura coberturaReportFile: 'reports/coverage.xml'
                        junit 'reports/**/*.xml'
                    }
                }
                stage('3.10'){
                    agent{
                        dockerfile {
                            filename 'python310.Dockerfile'
                            dir '.ci'
                            args '-v pip-cache:/home/jenkins/.cache/pip'
                        }
                    }
                    steps {
                        sh 'pip install pip --upgrade --progress-bar off'
                        sh 'pip install .[all] --progress-bar off'
                        sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
                        cobertura coberturaReportFile: 'reports/coverage.xml'
                        junit 'reports/**/*.xml'
                    }
                }
            }
        }
    }
}

README.md
View File

@@ -2,13 +2,12 @@
 ![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png)

-[![Build Status](https://travis-ci.org/si-cim/prototorch.svg?branch=master)](https://travis-ci.org/si-cim/prototorch)
+[![Build Status](https://api.travis-ci.com/si-cim/prototorch.svg?branch=master)](https://travis-ci.com/github/si-cim/prototorch)
 ![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg)
 [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases)
 [![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/)
 [![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch)
 [![Codacy Badge](https://api.codacy.com/project/badge/Grade/76273904bf9343f0a8b29cd8aca242e7)](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
-![PyPI - Downloads](https://img.shields.io/pypi/dm/prototorch?color=blue)
 [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE)

 *Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
@@ -48,6 +47,23 @@ pip install -e .[all]
 The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
 that link not work try <https://prototorch.readthedocs.io/en/latest/>.

+## Contribution
+
+This repository contains definition for [git hooks](https://githooks.com).
+[Pre-commit](https://pre-commit.com) is automatically installed as development
+dependency with prototorch or you can install it manually with `pip install
+pre-commit`.
+
+Please install the hooks by running:
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+before creating the first commit.
+
+The commit will fail if the commit message does not follow the specification
+provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
+
 ## Bibtex

 If you would like to cite the package, please use this:

46
deprecated.travis.yml Normal file
View File

@@ -0,0 +1,46 @@
dist: bionic
sudo: false
language: python
python:
  - 3.9
  - 3.8
  - 3.7
  - 3.6
cache:
  directories:
    - "$HOME/.cache/pip"
    - "./tests/artifacts"
    - "$HOME/datasets"
install:
  - pip install .[all] --progress-bar off
# Generate code coverage report
script:
  - coverage run -m pytest
# Push the results to codecov
after_success:
  - bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
jobs:
  include:
    - stage: build
      python: 3.9
      script: echo "Starting Pypi build"
      deploy:
        provider: pypi
        username: __token__
        distributions: "sdist bdist_wheel"
        password:
          secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
        on:
          tags: true
          skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

docs/source/conf.py
View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 # The full version, including alpha/beta/rc tags
 #
-release = "0.5.0"
+release = "0.7.1"

 # -- General configuration ---------------------------------------------------

96
examples/cbc_iris.py Normal file
View File

@@ -0,0 +1,96 @@
"""ProtoTorch CBC example using 2D Iris data."""
import torch
from matplotlib import pyplot as plt
import prototorch as pt
class CBC(torch.nn.Module):
def __init__(self, data, **kwargs):
super().__init__(**kwargs)
self.components_layer = pt.components.ReasoningComponents(
distribution=[2, 1, 2],
components_initializer=pt.initializers.SSCI(data, noise=0.1),
reasonings_initializer=pt.initializers.PPRI(components_first=True),
)
def forward(self, x):
components, reasonings = self.components_layer()
sims = pt.similarities.euclidean_similarity(x, components)
probs = pt.competitions.cbcc(sims, reasonings)
return probs
class VisCBC2D():
def __init__(self, model, data):
self.model = model
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
self.title = "Components Visualization"
self.fig = plt.figure(self.title)
self.border = 0.1
self.resolution = 100
self.cmap = "viridis"
def on_epoch_end(self):
x_train, y_train = self.x_train, self.y_train
_components = self.model.components_layer._components.detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(
x_train[:, 0],
x_train[:, 1],
c=y_train,
cmap=self.cmap,
edgecolor="k",
marker="o",
s=30,
)
ax.scatter(
_components[:, 0],
_components[:, 1],
c="w",
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = torch.vstack((x_train, _components))
mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
with torch.no_grad():
y_pred = self.model(
torch.Tensor(mesh_input).type_as(_components)).argmax(1)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
plt.pause(0.2)
if __name__ == "__main__":
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
model = CBC(train_ds)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = pt.losses.MarginLoss(margin=0.1)
vis = VisCBC2D(model, train_ds)
for epoch in range(200):
correct = 0.0
for x, y in train_loader:
y_oh = torch.eye(3)[y]
y_pred = model(x)
loss = criterion(y_pred, y_oh).mean(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
correct += (y_pred.argmax(1) == y).float().sum(0)
acc = 100 * correct / len(train_ds)
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_epoch_end()

examples/glvq_iris.py
View File

@@ -1,120 +0,0 @@
"""ProtoTorch GLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)
# Define the GLVQ model
class Model(torch.nn.Module):
def __init__(self):
"""GLVQ model for training on 2D Iris data."""
super().__init__()
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
def forward(self, x):
prototypes, prototype_labels = self.proto_layer()
distances = euclidean_distance(x, prototypes)
return distances, prototype_labels
# Build the GLVQ model
model = Model()
# Print summary using torchinfo (might be buggy/incorrect)
print(summary(model))
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)
for epoch in range(70):
# Compute loss
distances, prototype_labels = model(x_in)
loss = criterion([distances, prototype_labels], y_in)
# Compute Accuracy
with torch.no_grad():
predictions = wtac(distances, prototype_labels)
correct = predictions.eq(y_in.view_as(predictions)).sum().item()
acc = 100.0 * correct / len(x_train)
print(
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
)
# Optimizer step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Get the prototypes form the model
prototypes = model.proto_layer.components.numpy()
if np.isnan(np.sum(prototypes)):
print("Stopping training because of `nan` in prototypes.")
break
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
ax.set_title(TITLE)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
prototypes[:, 0],
prototypes[:, 1],
c=prototype_labels,
cmap=cmap,
edgecolor="k",
marker="D",
s=50,
)
# Paint decision regions
x = np.vstack((x_train, prototypes))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0]
w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(prototype_labels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)

examples/gmlvq_tecator.py
View File

@@ -1,103 +0,0 @@
"""ProtoTorch "siamese" GMLVQ example using Tecator."""
import matplotlib.pyplot as plt
import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
# Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
class Model(torch.nn.Module):
def __init__(self, **kwargs):
"""GMLVQ model as a siamese network."""
super().__init__()
prototype_initializer = StratifiedMeanInitializer(train_loader)
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
self.omega = torch.nn.Linear(in_features=100,
out_features=100,
bias=False)
torch.nn.init.eye_(self.omega.weight)
def forward(self, x):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
# Process `x` and `protos` through `omega`
x_map = self.omega(x)
protos_map = self.omega(protos)
# Compute distances and output
dis = sed(x_map, protos_map)
return dis, plabels
# Build the GLVQ model
model = Model()
# Print a summary of the model
print(model)
# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing="identity", beta=10)
# Training loop
for epoch in range(150):
epoch_loss = 0.0 # zero-out epoch loss
optimizer.zero_grad() # zero-out gradients
for xb, yb in train_loader:
# Compute loss
distances, plabels = model(xb)
loss = criterion([distances, plabels], yb)
epoch_loss += loss.item()
# Backprop
loss.backward()
# Take a gradient descent step
optimizer.step()
scheduler.step()
lr = optimizer.param_groups[0]["lr"]
print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")
# Get the omega matrix form the model
omega = model.omega.weight.data.numpy().T
# Visualize the lambda matrix
title = "Lambda Matrix Visualization"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show()
# Get the prototypes form the model
protos = model.proto_layer.components.numpy()
plabels = model.proto_layer.component_labels.numpy()
# Visualize the prototypes
title = "Tecator Prototypes"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel("Spectral frequencies")
ax.set_ylabel("Absorption")
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
for x, y in zip(protos, plabels):
ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()

examples/gtlvq_mnist.py
View File

@@ -1,183 +0,0 @@
"""
ProtoTorch GTLVQ example using MNIST data.
The GTLVQ is placed as an classification model on
top of a CNN, considered as featurer extractor.
Initialization of subpsace and prototypes in
Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ
from torchvision import transforms
# Parameters and options
num_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:0"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
# Configures reproducability
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
"./files/",
train=True,
download=True,
transform=torchvision.transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]),
),
batch_size=batch_size_train,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
"./files/",
train=False,
download=True,
transform=torchvision.transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]),
),
batch_size=batch_size_test,
shuffle=True,
)
# Define the GLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
def __init__(
self,
num_classes,
subspace_data,
prototype_data,
tangent_projection_type="local",
prototypes_per_class=2,
bottleneck_dim=128,
):
super(CNNGTLVQ, self).__init__()
# Feature Extractor - Simple CNN
self.fe = nn.Sequential(
nn.Conv2d(1, 32, 3, 1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout(0.25),
nn.Flatten(),
nn.Linear(9216, bottleneck_dim),
nn.Dropout(0.5),
nn.LeakyReLU(),
nn.LayerNorm(bottleneck_dim),
)
# Forward pass of subspace and prototype initialization data through feature extractor
subspace_data = self.fe(subspace_data)
prototype_data[0] = self.fe(prototype_data[0])
# Initialization of GTLVQ
self.gtlvq = GTLVQ(
num_classes,
subspace_data,
prototype_data,
tangent_projection_type=tangent_projection_type,
feature_dim=bottleneck_dim,
prototypes_per_class=prototypes_per_class,
)
def forward(self, x):
# Feature Extraction
x = self.fe(x)
# GTLVQ Forward pass
dis = self.gtlvq(x)
return dis
# Get init data
subspace_data = torch.cat(
[next(iter(train_loader))[0],
next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))
# Build the CNN GTLVQ model
model = CNNGTLVQ(
10,
subspace_data,
prototype_data,
tangent_projection_type="local",
bottleneck_dim=128,
).to(device)
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.Adam(
[{
"params": model.fe.parameters()
}, {
"params": model.gtlvq.parameters()
}],
lr=learning_rate,
)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop
for epoch in range(num_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train()
x_train, y_train = x_train.to(device), y_train.to(device)
optimizer.zero_grad()
distances = model(x_train)
plabels = model.gtlvq.cls.component_labels.to(device)
# Compute loss.
loss = criterion([distances, plabels], y_train)
loss.backward()
optimizer.step()
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
model.gtlvq.orthogonalize_subspace()
if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels)
print(
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}")
# Test
with torch.no_grad():
model.eval()
correct = 0
total = 0
for x_test, y_test in test_loader:
x_test, y_test = x_test.to(device), y_test.to(device)
test_distances = model(torch.tensor(x_test))
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
i = torch.argmin(test_distances, 1)
correct += torch.sum(y_test == test_plabels[i])
total += y_test.size(0)
print("Accuracy of the network on the test images: %d %%" %
(torch.true_divide(correct, total) * 100))
# Save the model
PATH = "./glvq_mnist_model.pth"
torch.save(model.state_dict(), PATH)

examples/lgmlvq_iris.py
View File

@@ -1,108 +0,0 @@
"""ProtoTorch LGMLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
# Prepare training data
x_train, y_train = load_iris(True)
x_train = x_train[:, [0, 2]]
# Define the model
class Model(torch.nn.Module):
def __init__(self):
"""Local-GMLVQ model."""
super().__init__()
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
prototype_distribution = [1, 2, 2]
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
omegas = torch.eye(2, 2).repeat(5, 1, 1)
self.omegas = torch.nn.Parameter(omegas)
def forward(self, x):
protos, plabels = self.proto_layer()
omegas = self.omegas
dis = lomega_distance(x, protos, omegas)
return dis, plabels
# Build the model
model = Model()
# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(100):
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
acc = accuracy_score(y_train, y_pred)
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
log_string += f"Acc: {acc * 100:05.02f}%"
print(log_string)
# Take a gradient descent step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Get the prototypes form the model
protos = model.proto_layer.components.numpy()
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
ax.set_title(title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=cmap,
edgecolor="k",
marker="D",
s=50,
)
# Paint decision regions
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
d, plabels = model(torch.Tensor(mesh_input))
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)

prototorch/__init__.py
View File

@@ -1,6 +1,7 @@
"""ProtoTorch package""" """ProtoTorch package"""
import pkgutil import pkgutil
from typing import List
import pkg_resources import pkg_resources
@@ -21,7 +22,7 @@ from .core import (
) )
# Core Setup # Core Setup
__version__ = "0.5.0" __version__ = "0.7.1"
__all_core__ = [ __all_core__ = [
"competitions", "competitions",
@@ -39,7 +40,7 @@ __all_core__ = [
] ]
# Plugin Loader # Plugin Loader
__path__ = pkgutil.extend_path(__path__, __name__) __path__: List[str] = pkgutil.extend_path(__path__, __name__)
def discover_plugins(): def discover_plugins():

prototorch/core/competitions.py
View File

@@ -3,8 +3,7 @@
 import torch


-def wtac(distances: torch.Tensor,
-         labels: torch.LongTensor) -> (torch.LongTensor):
+def wtac(distances: torch.Tensor, labels: torch.LongTensor):
     """Winner-Takes-All-Competition.

     Returns the labels corresponding to the winners.
@@ -15,9 +14,7 @@ def wtac(distances: torch.Tensor,
     return winning_labels


-def knnc(distances: torch.Tensor,
-         labels: torch.LongTensor,
-         k: int = 1) -> (torch.LongTensor):
+def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1):
     """K-Nearest-Neighbors-Competition.

     Returns the labels corresponding to the winners.
@@ -51,7 +48,7 @@ class WTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.

     """
-    def forward(self, distances, labels):
+    def forward(self, distances, labels):  # pylint: disable=no-self-use
         return wtac(distances, labels)
@@ -61,7 +58,7 @@ class LTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.

     """
-    def forward(self, probs, labels):
+    def forward(self, probs, labels):  # pylint: disable=no-self-use
         return wtac(-1.0 * probs, labels)
@@ -88,5 +85,5 @@ class CBCC(torch.nn.Module):
     Thin wrapper over the `cbcc` function.

     """
-    def forward(self, detections, reasonings):
+    def forward(self, detections, reasonings):  # pylint: disable=no-self-use
         return cbcc(detections, reasonings)
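
A hedged usage sketch of the trimmed `wtac` signature above, using the `pt.competitions` re-export seen in `examples/cbc_iris.py` (the printed result assumes `wtac` picks the label of the closest prototype per sample):

```python
import torch
import prototorch as pt

distances = torch.tensor([[1.0, 0.2, 0.5],
                          [0.3, 0.9, 0.1]])    # 2 samples x 3 prototypes
labels = torch.LongTensor([0, 1, 2])           # one label per prototype
print(pt.competitions.wtac(distances, labels))  # expected: tensor([1, 2])
```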

prototorch/core/components.py
View File

@@ -86,8 +86,8 @@ class AbstractComponents(torch.nn.Module):
 class Components(AbstractComponents):
     """A set of adaptable Tensors."""
     def __init__(self, num_components: int,
-                 initializer: AbstractComponentsInitializer, **kwargs):
-        super().__init__(**kwargs)
+                 initializer: AbstractComponentsInitializer):
+        super().__init__()
         self.add_components(num_components, initializer)

     def add_components(self, num_components: int,
@@ -154,9 +154,8 @@ class Labels(AbstractLabels):
     """A set of standalone labels."""
     def __init__(self,
                  distribution: Union[dict, list, tuple],
-                 initializer: AbstractLabelsInitializer = LabelsInitializer(),
-                 **kwargs):
-        super().__init__(**kwargs)
+                 initializer: AbstractLabelsInitializer = LabelsInitializer()):
+        super().__init__()
         self.add_labels(distribution, initializer)

     def add_labels(
@@ -184,13 +183,11 @@ class Labels(AbstractLabels):
 class LabeledComponents(AbstractComponents):
     """A set of adaptable components and corresponding unadaptable labels."""
     def __init__(
             self,
             distribution: Union[dict, list, tuple],
             components_initializer: AbstractComponentsInitializer,
-            labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
-            ),
-            **kwargs):
-        super().__init__(**kwargs)
+            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
+        super().__init__()
         self.add_components(distribution, components_initializer,
                             labels_initializer)
@@ -252,12 +249,14 @@ class Reasonings(torch.nn.Module):
     The `reasonings` tensor is of shape [num_components, num_classes, 2].

     """
-    def __init__(self,
-                 distribution: Union[dict, list, tuple],
-                 initializer:
-                 AbstractReasoningsInitializer = RandomReasoningsInitializer(),
-                 **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        distribution: Union[dict, list, tuple],
+        initializer:
+        AbstractReasoningsInitializer = RandomReasoningsInitializer(),
+    ):
+        super().__init__()
+        self.add_reasonings(distribution, initializer)

     @property
     def num_classes(self):
@@ -295,7 +294,7 @@ class Reasonings(torch.nn.Module):
 class ReasoningComponents(AbstractComponents):
-    """A set of components and a corresponding adapatable reasoning matrices.
+    r"""A set of components and a corresponding adapatable reasoning matrices.

     Every component has its own reasoning matrix.
@@ -310,13 +309,12 @@ class ReasoningComponents(AbstractComponents):
     """
     def __init__(
             self,
             distribution: Union[dict, list, tuple],
             components_initializer: AbstractComponentsInitializer,
             reasonings_initializer:
-            AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
-            **kwargs):
-        super().__init__(**kwargs)
+            AbstractReasoningsInitializer = PurePositiveReasoningsInitializer()):
+        super().__init__()
         self.add_components(distribution, components_initializer,
                             reasonings_initializer)

prototorch/core/distances.py
View File

@@ -41,9 +41,6 @@ def euclidean_distance_v2(x, y):
     # batch diagonal. See:
     # https://pytorch.org/docs/stable/generated/torch.diagonal.html
     distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
-    # print(f"{diff.shape=}")  # (nx, ny, ndim)
-    # print(f"{pairwise_distances.shape=}")  # (nx, ny, ny)
-    # print(f"{distances.shape=}")  # (nx, ny)
     return distances

prototorch/core/initializers.py
View File

@@ -3,7 +3,11 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Union
+from typing import (
+    Callable,
+    Type,
+    Union,
+)

 import torch
@@ -110,9 +114,9 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):

     """
     def __init__(self,
-                 data: torch.TensorType,
+                 data: torch.Tensor,
                  noise: float = 0.0,
-                 transform: callable = torch.nn.Identity()):
+                 transform: Callable = torch.nn.Identity()):
         self.data = data
         self.noise = noise
         self.transform = transform
@@ -151,14 +155,14 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
 class MeanCompInitializer(AbstractDataAwareCompInitializer):
     """Generate components by computing the mean of the provided data."""
     def generate(self, num_components: int):
-        mean = torch.mean(self.data, dim=0)
+        mean = self.data.mean(dim=0)
         repeat_dim = [num_components] + [1] * len(mean.shape)
         samples = mean.repeat(repeat_dim)
         components = self.generate_end_hook(samples)
         return components


-class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
+class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
     """Abstract class for all class-aware components initializers.

     Components generated by class-aware components initializers inherit the shape
@@ -171,13 +175,18 @@ class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
     def __init__(self,
                  data,
                  noise: float = 0.0,
-                 transform: callable = torch.nn.Identity()):
+                 transform: Callable = torch.nn.Identity()):
         self.data, self.targets = parse_data_arg(data)
         self.noise = noise
         self.transform = transform
         self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)

+    def generate_end_hook(self, samples):
+        drift = torch.rand_like(samples) * self.noise
+        components = self.transform(samples + drift)
+        return components
+
     @abstractmethod
     def generate(self, distribution: Union[dict, list, tuple]):
         ...
@@ -200,7 +209,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
     """Abstract class for all stratified components initializers."""
     @property
     @abstractmethod
-    def subinit_type(self) -> AbstractDataAwareCompInitializer:
+    def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
         ...

     def generate(self, distribution: Union[dict, list, tuple]):
@@ -276,10 +285,10 @@ class LabelsInitializer(AbstractLabelsInitializer):
     """Generate labels from `distribution`."""
     def generate(self, distribution: Union[dict, list, tuple]):
         distribution = parse_distribution(distribution)
-        labels = []
+        labels_list = []
         for k, v in distribution.items():
-            labels.extend([k] * v)
-        labels = torch.LongTensor(labels)
+            labels_list.extend([k] * v)
+        labels = torch.LongTensor(labels_list)
         return labels
@@ -294,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):

 # Reasonings
+def compute_distribution_shape(distribution):
+    distribution = parse_distribution(distribution)
+    num_components = sum(distribution.values())
+    num_classes = len(distribution.keys())
+    return (num_components, num_classes, 2)
+
+
 class AbstractReasoningsInitializer(ABC):
     """Abstract class for all reasonings initializers."""
     def __init__(self, components_first: bool = True):
         self.components_first = components_first

-    def compute_shape(self, distribution):
-        distribution = parse_distribution(distribution)
-        num_components = sum(distribution.values())
-        num_classes = len(distribution.keys())
-        return (num_components, num_classes, 2)
-
     def generate_end_hook(self, reasonings):
         if not self.components_first:
             reasonings = reasonings.permute(2, 1, 0)
@@ -340,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
 class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
     """Reasonings are all initialized with zeros."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.zeros(*shape)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -349,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
 class OnesReasoningsInitializer(AbstractReasoningsInitializer):
     """Reasonings are all initialized with ones."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.ones(*shape)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -363,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
         self.maximum = maximum

     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -372,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
 class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
     """Each component reasons positively for exactly one class."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        num_components, num_classes, _ = self.compute_shape(distribution)
+        num_components, num_classes, _ = compute_distribution_shape(
+            distribution)
         A = OneHotLabelsInitializer().generate(distribution)
         B = torch.zeros(num_components, num_classes)
         reasonings = torch.stack([A, B], dim=-1)
@@ -425,6 +436,33 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
         return self.generate_end_hook(weights)


+class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
+    """Abstract class for all data-aware linear transform initializers."""
+    def __init__(self,
+                 data: torch.Tensor,
+                 noise: float = 0.0,
+                 transform: Callable = torch.nn.Identity(),
+                 out_dim_first: bool = False):
+        super().__init__(out_dim_first)
+        self.data = data
+        self.noise = noise
+        self.transform = transform
+
+    def generate_end_hook(self, weights: torch.Tensor):
+        drift = torch.rand_like(weights) * self.noise
+        weights = self.transform(weights + drift)
+        if self.out_dim_first:
+            weights = weights.permute(1, 0)
+        return weights
+
+
+class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
+    """Initialize a matrix with Eigenvectors from the data."""
+    def generate(self, in_dim: int, out_dim: int):
+        _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
+        return self.generate_end_hook(weights)
+
+
 # Aliases - Components
 CACI = ClassAwareCompInitializer
 DACI = DataAwareCompInitializer
@@ -456,3 +494,4 @@ ZRI = ZerosReasoningsInitializer
 Eye = EyeTransformInitializer
 OLTI = OnesLinearTransformInitializer
 ZLTI = ZerosLinearTransformInitializer
+PCALTI = PCALinearTransformInitializer
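
A small sketch of the new PCA initializer added above; the `PCALTI` alias comes from this diff, while the `pt.initializers` re-export is an assumption based on the examples:

```python
import torch
import prototorch as pt

data = torch.randn(100, 4)                  # toy data with 4 features
init = pt.initializers.PCALTI(data)         # PCALinearTransformInitializer
omega = init.generate(in_dim=4, out_dim=2)  # top-2 principal directions via torch.pca_lowrank
print(omega.shape)                          # torch.Size([4, 2]) since out_dim_first=False
```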

prototorch/core/losses.py
View File

@@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
 class GLVQLoss(torch.nn.Module):
-    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
+    def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
         super().__init__(**kwargs)
         self.margin = margin
-        self.squashing = get_activation(squashing)
+        self.transfer_fn = get_activation(transfer_fn)
         self.beta = torch.tensor(beta)

-    def forward(self, outputs, targets):
-        distances, plabels = outputs
-        mu = glvq_loss(distances, targets, prototype_labels=plabels)
-        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
-        return torch.sum(batch_loss, dim=0)
+    def forward(self, outputs, targets, plabels):
+        mu = glvq_loss(outputs, targets, prototype_labels=plabels)
+        batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
+        return batch_loss.sum()


 class MarginLoss(torch.nn.modules.loss._Loss):
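
The refactored `GLVQLoss` now takes the prototype labels as an explicit third argument instead of unpacking them from `outputs`, and `squashing` was renamed to `transfer_fn`. A hedged sketch (the `pt.losses` re-export is assumed):

```python
import torch
import prototorch as pt

criterion = pt.losses.GLVQLoss(transfer_fn="identity", beta=10)
distances = torch.rand(8, 6)  # 8 samples x 6 prototypes
targets = torch.LongTensor([0, 1, 2, 0, 1, 2, 0, 1])
plabels = torch.LongTensor([0, 0, 1, 1, 2, 2])
loss = criterion(distances, targets, plabels)  # scalar: batch_loss.sum()
```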

prototorch/core/pooling.py
View File

@@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
 class StratifiedSumPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_sum_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_sum_pooling(values, labels)


 class StratifiedProdPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_prod_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_prod_pooling(values, labels)


 class StratifiedMinPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_min_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_min_pooling(values, labels)


 class StratifiedMaxPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_max_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_max_pooling(values, labels)

prototorch/core/transforms.py
View File

@@ -11,13 +11,12 @@ from .initializers import (
 class LinearTransform(torch.nn.Module):
     def __init__(
             self,
             in_dim: int,
             out_dim: int,
             initializer:
-            AbstractLinearTransformInitializer = EyeTransformInitializer(),
-            **kwargs):
-        super().__init__(**kwargs)
+            AbstractLinearTransformInitializer = EyeTransformInitializer()):
+        super().__init__()
         self.set_weights(in_dim, out_dim, initializer)

     @property
@@ -37,7 +36,7 @@ class LinearTransform(torch.nn.Module):
         self._register_weights(weights)

     def forward(self, x):
-        return x @ self.weights.T
+        return x @ self.weights


 # Aliases
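
With the change above, `forward` multiplies by `self.weights` directly, so the stored matrix is used as `(in_dim, out_dim)`. A hedged sketch assuming a `pt.transforms` re-export:

```python
import torch
import prototorch as pt

t = pt.transforms.LinearTransform(4, 2)  # EyeTransformInitializer() by default
x = torch.randn(8, 4)
print(t(x).shape)  # torch.Size([8, 2]) via x @ weights
```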

prototorch/datasets/__init__.py
View File

@@ -1,6 +1,6 @@
"""ProtoTorch datasets""" """ProtoTorch datasets"""
from .abstract import NumpyDataset from .abstract import CSVDataset, NumpyDataset
from .sklearn import ( from .sklearn import (
Blobs, Blobs,
Circles, Circles,
@@ -10,3 +10,4 @@ from .sklearn import (
) )
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
from .xor import XOR

prototorch/datasets/abstract.py
View File

@@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
 import os

+import numpy as np
 import torch
@@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
         self.targets = torch.LongTensor(targets)
         tensors = [self.data, self.targets]
         super().__init__(*tensors)
+
+
+class CSVDataset(NumpyDataset):
+    """Create a Dataset from a CSV file."""
+    def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
+        raw = np.genfromtxt(
+            filepath,
+            delimiter=delimiter,
+            skip_header=skip_header,
+        )
+        data = np.delete(raw, 1, target_col)
+        targets = raw[:, target_col]
+        super().__init__(data, targets)

prototorch/datasets/sklearn.py
View File

@@ -8,11 +8,11 @@ URL:
 import warnings
 from typing import Sequence, Union

-from prototorch.datasets.abstract import NumpyDataset
 from sklearn.datasets import (load_iris, make_blobs, make_circles,
                               make_classification, make_moons)

+from prototorch.datasets.abstract import NumpyDataset
+

 class Iris(NumpyDataset):
     """Iris Dataset by Ronald Fisher introduced in 1936.

prototorch/datasets/tecator.py
View File

@@ -40,9 +40,10 @@ import os
 import numpy as np
 import torch
-from prototorch.datasets.abstract import ProtoDataset
 from torchvision.datasets.utils import download_file_from_google_drive

+from prototorch.datasets.abstract import ProtoDataset
+

 class Tecator(ProtoDataset):
     """

prototorch/datasets/xor.py
View File

@@ -0,0 +1,18 @@
"""Exclusive-or (XOR) dataset for binary classification."""
import torch
def make_xor(num_samples=500):
x = torch.rand(num_samples, 2)
y = torch.zeros(num_samples)
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
return x, y
class XOR(torch.utils.data.TensorDataset):
"""Exclusive-or (XOR) dataset for binary classification."""
def __init__(self, num_samples: int = 500):
x, y = make_xor(num_samples)
super().__init__(x, y)
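
Usage sketch for the new dataset, exported as `pt.datasets.XOR` per the `datasets/__init__.py` diff above:

```python
import torch
import prototorch as pt

train_ds = pt.datasets.XOR(num_samples=1000)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([32, 2]) torch.Size([32])
```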

prototorch/utils/utils.py
View File

@@ -1,7 +1,12 @@
 """ProtoFlow utilities"""

 import warnings
-from typing import Union
+from typing import (
+    Dict,
+    Iterable,
+    List,
+    Union,
+)

 import numpy as np
 import torch
@@ -23,15 +28,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
     return mesh, xx, yy


-def distribution_from_list(list_dist: list[int], clabels: list[int] = []):
+def distribution_from_list(list_dist: List[int],
+                           clabels: Iterable[int] = None):
     clabels = clabels or list(range(len(list_dist)))
     distribution = dict(zip(clabels, list_dist))
     return distribution


-def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
-                                                list[int], tuple[int]],
-                       clabels: list[int] = []) -> dict[int, int]:
+def parse_distribution(user_distribution,
+                       clabels: Iterable[int] = None) -> Dict[int, int]:
     """Parse user-provided distribution.

     Return a dictionary with integer keys that represent the class labels and
@@ -75,9 +80,13 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
 def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
     """Return data and target as torch tensors."""
     if isinstance(data_arg, Dataset):
-        ds_size = len(data_arg)
-        loader = DataLoader(data_arg, batch_size=ds_size)
-        data, targets = next(iter(loader))
+        if hasattr(data_arg, "__len__"):
+            ds_size = len(data_arg)  # type: ignore
+            loader = DataLoader(data_arg, batch_size=ds_size)
+            data, targets = next(iter(loader))
+        else:
+            emsg = f"Dataset {data_arg} is not sized (`__len__` unimplemented)."
+            raise TypeError(emsg)
     elif isinstance(data_arg, DataLoader):
         data = torch.tensor([])
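A quick illustration of the changed defaults (a sketch; the import path assumes the helpers are re-exported from prototorch.utils):

    from prototorch.utils import distribution_from_list

    # `clabels` now defaults to None and falls back to 0..n-1, avoiding the
    # mutable-default pitfall of the old `clabels: list[int] = []` signature.
    assert distribution_from_list([2, 2, 3]) == {0: 2, 1: 2, 2: 3}
    assert distribution_from_list([2, 3], clabels=[4, 9]) == {4: 2, 9: 3}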

setup.py

@@ -1,10 +1,12 @@
 """
-(ASCII-art "ProtoTorch" banner, old figlet style)
+(ASCII-art "ProtoTorch" banner, new block style)

 ProtoTorch Core Package
 """
@@ -18,7 +20,7 @@ with open("README.md", "r") as fh:
 INSTALL_REQUIRES = [
     "torch>=1.3.1",
-    "torchvision>=0.5.0",
+    "torchvision>=0.7.1",
     "numpy>=1.9.1",
     "sklearn",
 ]
@@ -26,7 +28,10 @@ DATASETS = [
     "requests",
     "tqdm",
 ]
-DEV = ["bumpversion"]
+DEV = [
+    "bumpversion",
+    "pre-commit",
+]
 DOCS = [
     "recommonmark",
     "sphinx",
@@ -38,12 +43,15 @@ EXAMPLES = [
     "matplotlib",
     "torchinfo",
 ]
-TESTS = ["codecov", "pytest"]
+TESTS = [
+    "pytest-cov",
+    "pytest",
+]
 ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS

 setup(
     name="prototorch",
-    version="0.5.0",
+    version="0.7.1",
     description="Highly extensible, GPU-supported "
     "Learning Vector Quantization (LVQ) toolbox "
     "built using PyTorch and its nn API.",
@@ -54,30 +62,33 @@ setup(
     url=PROJECT_URL,
     download_url=DOWNLOAD_URL,
     license="MIT",
+    python_requires=">=3.6",
     install_requires=INSTALL_REQUIRES,
     extras_require={
-        "docs": DOCS,
         "datasets": DATASETS,
+        "dev": DEV,
+        "docs": DOCS,
         "examples": EXAMPLES,
         "tests": TESTS,
         "all": ALL,
     },
     classifiers=[
-        "Development Status :: 2 - Pre-Alpha",
         "Environment :: Console",
+        "Natural Language :: English",
+        "Development Status :: 4 - Beta",
         "Intended Audience :: Developers",
         "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
-        "Topic :: Scientific/Engineering :: Artificial Intelligence",
-        "Topic :: Software Development :: Libraries",
-        "Topic :: Software Development :: Libraries :: Python Modules",
         "License :: OSI Approved :: MIT License",
-        "Natural Language :: English",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.6",
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
-        "Operating System :: OS Independent",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     packages=find_packages(),
     zip_safe=False,
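In pip terms, the reworked extras mean that, for instance, `pip install "prototorch[dev]"` now pulls in pre-commit alongside bumpversion, and `pip install "prototorch[all]"` installs every optional group at once.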

tests/test_core.py

@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
 # Transforms
-def test_linear_transform():
+def test_linear_transform_default_eye_init():
     l = pt.transforms.LinearTransform(2, 4)
     actual = l.weights
     desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
     assert torch.allclose(actual, desired)


+def test_linear_transform_forward():
+    l = pt.transforms.LinearTransform(4, 2)
+    actual_weights = l.weights
+    desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
+    assert torch.allclose(actual_weights, desired_weights)
+    actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4],
+                                     [1.1, 2.2, 3.3, 4.4],
+                                     [5.5, 6.6, 7.7, 8.8]]))
+    desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
+    assert torch.allclose(actual_outputs, desired_outputs)
+
+
 def test_linear_transform_zeros_init():
     l = pt.transforms.LinearTransform(
         in_dim=2,

tests/test_datasets.py

@@ -49,6 +49,23 @@ class TestNumpyDataset(unittest.TestCase):
         self.assertEqual(len(ds), 3)


+class TestCSVDataset(unittest.TestCase):
+    def setUp(self):
+        data = np.random.rand(100, 4)
+        targets = np.random.randint(2, size=(100, 1))
+        arr = np.hstack([data, targets])
+        if not os.path.exists("./artifacts"):
+            os.mkdir("./artifacts")
+        np.savetxt("./artifacts/test.csv", arr, delimiter=",")
+
+    def test_len(self):
+        ds = pt.datasets.CSVDataset("./artifacts/test.csv")
+        self.assertEqual(len(ds), 100)
+
+    def tearDown(self):
+        os.remove("./artifacts/test.csv")
 class TestSpiral(unittest.TestCase):
     def test_init(self):
         ds = pt.datasets.Spiral(num_samples=10)
@@ -94,67 +111,67 @@ class TestMoons(unittest.TestCase):
         self.assertEqual(len(ds), 10)

-class TestTecator(unittest.TestCase):
-    def setUp(self):
-        self.artifacts_dir = "./artifacts/Tecator"
-        self._remove_artifacts()
-
-    def _remove_artifacts(self):
-        if os.path.exists(self.artifacts_dir):
-            shutil.rmtree(self.artifacts_dir)
-
-    def test_download_false(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        self._remove_artifacts()
-        with self.assertRaises(RuntimeError):
-            _ = pt.datasets.Tecator(rootdir, download=False)
-
-    def test_download_caching(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
-        _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
-
-    def test_repr(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
-        self.assertTrue("Split: Train" in train.__repr__())
-
-    def test_download_train(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        train = pt.datasets.Tecator(root=rootdir,
-                                    train=True,
-                                    download=True,
-                                    verbose=False)
-        train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
-        x_train, y_train = train.data, train.targets
-        self.assertEqual(x_train.shape[0], 144)
-        self.assertEqual(y_train.shape[0], 144)
-        self.assertEqual(x_train.shape[1], 100)
-
-    def test_download_test(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
-        x_test, y_test = test.data, test.targets
-        self.assertEqual(x_test.shape[0], 71)
-        self.assertEqual(y_test.shape[0], 71)
-        self.assertEqual(x_test.shape[1], 100)
-
-    def test_class_to_idx(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
-        _ = test.class_to_idx
-
-    def test_getitem(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
-        x, y = test[0]
-        self.assertEqual(x.shape[0], 100)
-        self.assertIsInstance(y, int)
-
-    def test_loadable_with_dataloader(self):
-        rootdir = self.artifacts_dir.rpartition("/")[0]
-        test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
-        _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
-
-    def tearDown(self):
-        self._remove_artifacts()
+# class TestTecator(unittest.TestCase):
+#     def setUp(self):
+#         self.artifacts_dir = "./artifacts/Tecator"
+#         self._remove_artifacts()
+
+#     def _remove_artifacts(self):
+#         if os.path.exists(self.artifacts_dir):
+#             shutil.rmtree(self.artifacts_dir)
+
+#     def test_download_false(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         self._remove_artifacts()
+#         with self.assertRaises(RuntimeError):
+#             _ = pt.datasets.Tecator(rootdir, download=False)
+
+#     def test_download_caching(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
+#         _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
+
+#     def test_repr(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
+#         self.assertTrue("Split: Train" in train.__repr__())
+
+#     def test_download_train(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         train = pt.datasets.Tecator(root=rootdir,
+#                                     train=True,
+#                                     download=True,
+#                                     verbose=False)
+#         train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
+#         x_train, y_train = train.data, train.targets
+#         self.assertEqual(x_train.shape[0], 144)
+#         self.assertEqual(y_train.shape[0], 144)
+#         self.assertEqual(x_train.shape[1], 100)
+
+#     def test_download_test(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
+#         x_test, y_test = test.data, test.targets
+#         self.assertEqual(x_test.shape[0], 71)
+#         self.assertEqual(y_test.shape[0], 71)
+#         self.assertEqual(x_test.shape[1], 100)
+
+#     def test_class_to_idx(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
+#         _ = test.class_to_idx
+
+#     def test_getitem(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
+#         x, y = test[0]
+#         self.assertEqual(x.shape[0], 100)
+#         self.assertIsInstance(y, int)
+
+#     def test_loadable_with_dataloader(self):
+#         rootdir = self.artifacts_dir.rpartition("/")[0]
+#         test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
+#         _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
+
+#     def tearDown(self):
+#         self._remove_artifacts()