88 Commits

Author SHA1 Message Date
Alexander Engelsberger
ae75b9ebf7 Bump version: 0.2.0 → 0.3.0-dev0 2021-04-21 14:57:45 +02:00
Alexander Engelsberger
34973808b8 Improve documentation. 2021-04-21 14:55:54 +02:00
Alexander Engelsberger
c42df6e203 Merge version 0.2.0 into feature/plugin-architecture. 2021-04-19 16:44:26 +02:00
Jensun Ravichandran
101b50f4e6 Update prototypes.py
Changes:
1. Change single-quotes to double-quotes.
2021-04-15 12:35:06 +02:00
Jensun Ravichandran
db842b79bb Bump version: 0.1.1-rc0 → 0.2.0 2021-04-14 19:21:14 +02:00
Jensun Ravichandran
98a8fc52fa Add docs 2021-04-14 19:20:08 +02:00
Jensun Ravichandran
6796ec494f Merge pull request #1 from ChristophRaab/dev
gtlvq
2021-04-14 16:18:30 +02:00
Alexander Engelsberger
cd9303267b Use git version. 2021-04-14 13:48:00 +02:00
Alexander Engelsberger
599dfc3fda Fix issue with plugin subpackage import. 2021-04-13 22:55:49 +02:00
Alexander Engelsberger
5b2ab34232 Add plugin loader. 2021-04-13 12:36:22 +02:00
Jensun Ravichandran
429570323e Update iris example 2021-03-26 16:06:11 +01:00
Jensun Ravichandran
3edb13baf4 Update examples/glvq_iris.py script 2021-03-01 18:52:54 +01:00
Jensun Ravichandran
42cedbb2b8 Fix imports in examples/gmlvq_tecator.py 2021-03-01 18:45:41 +01:00
Jensun Ravichandran
2322876eb6 Update .travis.yml 2021-02-10 17:04:04 +01:00
Jensun Ravichandran
bc7df1059f Add utils folder with color utils 2021-02-10 17:03:12 +01:00
Jensun Ravichandran
4c7c9cc34a Update setup.py and README.md 2021-02-10 17:02:02 +01:00
Christoph
e39f307194 Another Codacy bug fix 2021-01-14 11:27:20 +01:00
Christoph
e2867f696e Anoter Codacy bug fix 2021-01-14 11:18:25 +01:00
Christoph
30dc0ea8b1 Codacy Bug Report fixes 2021-01-14 10:04:43 +01:00
Christoph
895281aabd gtlvq 2021-01-12 18:11:46 +01:00
Jensun Ravichandran
a55320a65b Add local gmlvq example 2020-09-24 16:59:42 +02:00
Jensun Ravichandran
559f4acc73 Update readme 2020-09-24 12:01:50 +02:00
Jensun Ravichandran
9b5bccc39d Update readme 2020-09-24 11:54:32 +02:00
Jensun Ravichandran
a8a99f6971 Update iris example 2020-09-24 11:54:18 +02:00
Jensun Ravichandran
58efa5a4cf Fix things codacy complains about 2020-09-24 11:53:35 +02:00
Jensun Ravichandran
9672aab8e2 Add codacy config file 2020-09-24 11:27:56 +02:00
Jensun Ravichandran
d5ab9c3771 Fix divide-by-zero in example 2020-09-23 15:29:26 +02:00
Jensun Ravichandran
3e6aa6a20b Update example 2020-08-04 11:30:50 +02:00
Jensun Ravichandran
b138277608 Fix int fill-value error in test_modules.py 2020-07-30 11:42:37 +02:00
Jensun Ravichandran
9ccbec52f7 Update install requirements and readme 2020-07-30 11:19:02 +02:00
Jensun Ravichandran
cd652508b9 Update manifest 2020-07-13 09:32:38 +02:00
Jensun Ravichandran
fa72c7156e Update tests/test_modules.py 2020-07-13 09:32:12 +02:00
Jensun Ravichandran
6e72b9267a Add siamese example using GMLVQ and Tecator 2020-07-13 09:31:48 +02:00
blackfly
8a4a596035 Make prototype_labels non-trainable Parameters 2020-04-27 13:39:27 +02:00
blackfly
0cfbc0473b Bump version: 0.1.1-dev0 → 0.1.1-rc0 2020-04-27 12:56:42 +02:00
blackfly
cf0659d881 Add test cases to test newly added features 2020-04-27 12:49:54 +02:00
blackfly
d17b9a3346 Modify stratified_min function 2020-04-27 12:48:12 +02:00
blackfly
532f63b1de Add one-hot support in functions/initializers.py 2020-04-27 12:47:44 +02:00
blackfly
c11a3860df Refactor functions/losses.py 2020-04-27 12:47:15 +02:00
blackfly
dab91e471a Add minor cosmetic changes 2020-04-27 12:45:42 +02:00
blackfly
a167565857 Update Prototypes1D 2020-04-27 12:44:19 +02:00
blackfly
e063625486 Remove some requirements from requirements.txt 2020-04-15 12:12:44 +02:00
blackfly
89eb5358a0 Try fixing tqdm AttributeError 2020-04-14 20:26:49 +02:00
blackfly
5c59515128 Update github action 'tests' 2020-04-14 20:19:23 +02:00
blackfly
7eb7a6b194 Update .travis.yml 2020-04-14 20:19:15 +02:00
blackfly
5811c4b9f9 Add requirements.txt 2020-04-14 20:18:45 +02:00
blackfly
7b1887d56e Add 'requests' requirements for downloading datasets 2020-04-14 20:04:10 +02:00
blackfly
63a25e7a38 Refactor examples/glvq_iris.py 2020-04-14 19:57:19 +02:00
blackfly
a0f20a40f6 Add test cases to test recently added features 2020-04-14 19:53:51 +02:00
blackfly
88cbe0a126 Add alias for squared_euclidean_distance 2020-04-14 19:53:26 +02:00
blackfly
a3548e0ddd Add stratified_min competition function 2020-04-14 19:52:59 +02:00
blackfly
3cfbc49254 Fix generator bug in stratified_random initializer 2020-04-14 19:51:54 +02:00
blackfly
2b82830590 Add 'datasets' to main package __init__.py 2020-04-14 19:51:14 +02:00
blackfly
553b1e1a65 Refactor datasets and use float32 instead of float64 in Tecator 2020-04-14 19:49:59 +02:00
blackfly
a9d2855323 Refactor prototypes module and begin documentation 2020-04-14 19:48:46 +02:00
blackfly
cf7d7b5d9d Add tests/test_datasets.py 2020-04-14 19:47:59 +02:00
blackfly
a22c752342 Add prototorch/datasets 2020-04-14 19:47:34 +02:00
blackfly
4158586cb9 More cosmetic changes 2020-04-11 18:12:37 +02:00
blackfly
f80d9648c3 Minor cosmetic changes 2020-04-11 17:35:32 +02:00
blackfly
e54bf07030 Populate init files 2020-04-11 17:35:00 +02:00
blackfly
8c629c0cb1 Fix a bunch of codacy code-style issues 2020-04-11 15:47:26 +02:00
blackfly
8f3a43f62a Remove assert statements following codacy security recommendation
"Use of assert detected. The enclosed code will be removed when compiling to
optimised byte code."
2020-04-11 15:45:29 +02:00
blackfly
955661af95 Remove utils import from prototorch/__init__.py 2020-04-11 15:12:53 +02:00
blackfly
c54d14c55e Remove datasets import from prototorch/__init__.py 2020-04-11 14:59:11 +02:00
blackfly
6090aad176 Update examples/glvq_iris.py to use the recently modified API 2020-04-11 14:29:06 +02:00
blackfly
1ec7bd261b Add small API changes and more test cases 2020-04-11 14:28:22 +02:00
blackfly
da3b0cc262 Update RELEASE.md 2020-04-11 14:26:05 +02:00
blackfly
f640a22cf2 Rename input to x in activation functions 2020-04-11 14:25:35 +02:00
blackfly
c843ace63d Update README.md 2020-04-11 14:22:34 +02:00
blackfly
242c9de3b6 Fix codecov reporting in .travis.yml 2020-04-08 23:37:11 +02:00
blackfly
438a5b9360 Bump version: 0.1.0-rc0 → 0.1.1-dev0 2020-04-08 23:00:34 +02:00
blackfly
f98f3d095e Update .travis.yml to cache artifacts from test scripts 2020-04-08 22:47:31 +02:00
blackfly
21b0279839 Add test cases 2020-04-08 22:47:08 +02:00
blackfly
b19cbcb76a Fix zero-distance bug in glvq_loss 2020-04-08 22:46:08 +02:00
blackfly
7d5ab81dbf Clean up prototorch/functions/distances.py 2020-04-08 22:44:02 +02:00
blackfly
bde408a80e Prepare activation and competition functions for TorchScript 2020-04-08 22:42:56 +02:00
blackfly
900955d67a Rename tests github action 2020-04-08 22:34:26 +02:00
blackfly
3757c937b3 Bump version: 0.1.0-dev0 → 0.1.0-rc0 2020-04-06 21:49:52 +02:00
blackfly
38f637aaeb Add build status batch from travis 2020-04-06 21:38:47 +02:00
blackfly
6ddfe48a95 Use bionic distribution instead of trusty
Downloading archive: ...binaries/ubuntu/14.04/x86_64/python-3.8.tar.bz2
$ curl -sSf --retry 5 -o python-3.8.tar.bz2 ${archive_url}
curl: (22) The requested URL returned error: 404 Not Found
Unable to download 3.8 archive. The archive may not exist.
Please consider a different version.
2020-04-06 21:21:14 +02:00
blackfly
bf0e694321 Add missing torch dependency in travis.yml 2020-04-06 21:16:43 +02:00
blackfly
e2c9848120 Update tox.ini to use coverage 2020-04-06 21:05:57 +02:00
blackfly
dc60b7e5b5 Add .travis.yml 2020-04-06 21:05:20 +02:00
blackfly
c21913fdd4 Add tests/__init__.py
Adding the __init__.py file makes it possible to run `coverage run -m pytest`
from the project root.
2020-04-06 21:01:50 +02:00
blackfly
59e31f94ab Add more version badges and bibtex section to README.md 2020-04-06 19:59:52 +02:00
blackfly
cddefa9b0d Add RELEASE.md 2020-04-06 18:52:12 +02:00
blackfly
26d71fdd60 Add version badges to README.md 2020-04-06 18:48:02 +02:00
blackfly
ced8f532dd Update MANIFEST.in to include codecov and test scripts 2020-04-06 18:32:06 +02:00
44 changed files with 2596 additions and 400 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.0-dev0
current_version = 0.3.0-dev0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
@@ -19,3 +19,4 @@ values =
[bumpversion:file:./prototorch/__init__.py]
[bumpversion:file:./docs/source/conf.py]

15
.codacy.yml Normal file
View File

@@ -0,0 +1,15 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

View File

@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Tests
name: tests
on:
push:
@@ -24,6 +24,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
- name: Install extras
run: |
pip install -r requirements.txt
- name: Lint with flake8
run: |
pip install flake8

27
.readthedocs.yml Normal file
View File

@@ -0,0 +1,27 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
fail_on_warning: true
# Build documentation with MkDocs
# mkdocs:
# configuration: mkdocs.yml
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
install:
- method: pip
path: .
extra_requirements:
- all

36
.travis.yml Normal file
View File

@@ -0,0 +1,36 @@
dist: bionic
sudo: false
language: python
python: 3.8
cache:
directories:
- "./tests/artifacts"
# - "$HOME/.prototorch/datasets"
install:
- pip install . --progress-bar off
- pip install -r requirements.txt
# Generate code coverage report
script:
- coverage run -m pytest
# Push the results to codecov
after_success:
- bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
deploy:
provider: pypi
username: __token__
password:
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
on:
tags: true
skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@@ -1,9 +1,13 @@
include .bumpversion.cfg
include LICENSE
include tox.ini
include *.md
include *.txt
include *.yml
recursive-include docs *.bat
recursive-include docs *.png
recursive-include docs *.py
recursive-include docs *.rst
recursive-include docs Makefile
recursive-include examples *.py
recursive-include tests *.py

View File

@@ -1,49 +1,62 @@
# ProtoTorch
# ProtoTorch: Prototype Learning in PyTorch
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
prototype-based machine learning algorithms.
![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png)
![Tests](https://github.com/si-cim/prototorch/workflows/Tests/badge.svg)
[![Build Status](https://travis-ci.org/si-cim/prototorch.svg?branch=master)](https://travis-ci.org/si-cim/prototorch)
![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/)
[![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/76273904bf9343f0a8b29cd8aca242e7)](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=si-cim/prototorch&amp;utm_campaign=Badge_Grade)
![PyPI - Downloads](https://img.shields.io/pypi/dm/prototorch?color=blue)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE)
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
## Description
This is a Python toolbox brewed at the Mittweida University of Applied Sciences
in Germany for bleeding-edge research in Learning Vector Quantization (LVQ)
and potentially other prototype-based methods. Although, there are
other (perhaps more extensive) LVQ toolboxes available out there, the focus of
ProtoTorch is ease-of-use, extensibility and speed.
Many popular prototype-based Machine Learning (ML) algorithms like K-Nearest
Neighbors (KNN), Generalized Learning Vector Quantization (GLVQ) and Generalized
Matrix Learning Vector Quantization (GMLVQ) are implemented using the "nn" API
provided by PyTorch.
in Germany for bleeding-edge research in Prototype-based Machine Learning
methods and other interpretable models. The focus of ProtoTorch is ease-of-use,
extensibility and speed.
## Installation
ProtoTorch can be installed using `pip`.
```bash
pip install -U prototorch
```
pip install prototorch
To also install the extras, use
```bash
pip install -U prototorch[all]
```
*Note: If you're using [ZSH](https://www.zsh.org/), the square brackets `[ ]`
have to be escaped like so: `\[\]`, making the install command `pip install -U
prototorch\[all\]`.*
To install the bleeding-edge features and improvements:
```
```bash
git clone https://github.com/si-cim/prototorch.git
git checkout dev
cd prototorch
pip install -e .
pip install -e .[all]
```
## Usage
## Documentation
ProtoTorch is modular. It is very easy to use the modular pieces provided by
ProtoTorch, like the layers, losses, callbacks and metrics to build your own
prototype-based(instance-based) models. These pieces blend-in seamlessly with
numpy and PyTorch to allow you mix and match the modules from ProtoTorch with
other PyTorch modules.
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
that link not work try <https://prototorch.readthedocs.io/en/latest/>.
ProtoTorch comes prepackaged with many popular LVQ algorithms in a convenient
API, with more algorithms and techniques coming soon. If you would simply like
to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
lets you do this without requiring a black-belt in high-performance Tensor
computation.
## Bibtex
If you would like to cite the package, please use this:
```bibtex
@misc{Ravichandran2020b,
author = {Ravichandran, J},
title = {ProtoTorch},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/si-cim/prototorch}}
}

16
RELEASE.md Normal file
View File

@@ -0,0 +1,16 @@
# ProtoTorch Releases
## Release 0.2.0
### Includes
- Fixes in example scripts.
## Release 0.1.1-dev0
### Includes
- Minor bugfixes.
- 100% line coverage.
## Release 0.1.0-dev0
Initial public release of ProtoTorch.

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= python3 -m sphinx
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

4
docs/requirements.txt Normal file
View File

@@ -0,0 +1,4 @@
torch==1.6.0
matplotlib==3.1.2
sphinx_rtd_theme==0.5.0
sphinxcontrib-katex==0.6.1

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

46
docs/source/api.rst Normal file
View File

@@ -0,0 +1,46 @@
.. ProtoFlow API Reference
ProtoFlow API Reference
======================================
Datasets
--------------------------------------
.. automodule:: prototorch.datasets
:members:
:undoc-members:
Functions
--------------------------------------
**Dimensions:**
- :math:`B` ... Batch size
- :math:`P` ... Number of prototypes
- :math:`n_x` ... Data dimension for vectorial data
- :math:`n_w` ... Data dimension for vectorial prototypes
Activations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.functions.activations
:members:
:exclude-members: register_activation, get_activation
:undoc-members:
Distances
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.functions.distances
:members:
:exclude-members: sed
:undoc-members:
Modules
--------------------------------------
.. automodule:: prototorch.modules
:members:
:undoc-members:
Utilities
--------------------------------------
.. automodule:: prototorch.utils
:members:
:undoc-members:

180
docs/source/conf.py Normal file
View File

@@ -0,0 +1,180 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath("../../"))
import sphinx_rtd_theme
# -- Project information -----------------------------------------------------
project = "ProtoTorch"
copyright = "2021, Jensun Ravichandran"
author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.3.0-dev0"
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
# ones.
extensions = [
"recommonmark",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
"sphinxcontrib.katex",
]
# katex_prerender = True
katex_prerender = False
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]
# The master toctree document.
master_doc = "index"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
# "lovelace", "algol", "algol_nu", "arduino", "rainbo w_dash", "abap",
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
# "stata-dark", "inkpot"]
pygments_style = "monokai"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# Disable docstring inheritance
autodoc_inherit_docstrings = False
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/horizontal-lockup.png"
html_theme_options = {
"logo_only": True,
"display_version": True,
"prev_next_buttons_location": "bottom",
"style_external_links": False,
"style_nav_header_background": "#ffffff",
# Toc options
"collapse_navigation": True,
"sticky_navigation": True,
"navigation_depth": 4,
"includehidden": True,
"titles_only": False,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = [
"https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "protoflowdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ("letterpaper" or "a4paper").
#
# "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt").
#
# "pointsize": "10pt",
# Additional stuff for the LaTeX preamble.
#
# "preamble": "",
# Latex figure (float) alignment
#
# "figure_align": "htbp",
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "prototorch.tex", "ProtoTorch Documentation",
"Jensun Ravichandran", "manual"),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
"Prototype-based machine learning in PyTorch.",
"Miscellaneous"),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
}
# -- Options for Epub output ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
epub_cover = ()
version = release

22
docs/source/index.rst Normal file
View File

@@ -0,0 +1,22 @@
.. ProtoTorch documentation master file
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
About ProtoTorch
================
.. toctree::
:hidden:
:maxdepth: 3
:caption: Contents:
self
api
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
research in prototype-based machine learning algorithms.
Indices
=======
* :ref:`genindex`
* :ref:`modindex`

View File

@@ -1,18 +1,19 @@
"""ProtoTorch GLVQ example using 2D Iris data"""
"""ProtoTorch GLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import AddPrototypes1D
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data
scaler = StandardScaler()
x_train, y_train = load_iris(True)
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)
@@ -20,16 +21,19 @@ x_train = scaler.transform(x_train)
# Define the GLVQ model
class Model(torch.nn.Module):
def __init__(self, **kwargs):
def __init__(self):
"""GLVQ model for training on 2D Iris data."""
super().__init__()
self.p1 = AddPrototypes1D(input_dim=2,
prototypes_per_class=1,
nclasses=3,
prototype_initializer='zeros')
self.proto_layer = Prototypes1D(
input_dim=2,
prototypes_per_class=3,
nclasses=3,
prototype_initializer="stratified_random",
data=[x_train, y_train])
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
protos = self.proto_layer.prototypes
plabels = self.proto_layer.prototype_labels
dis = euclidean_distance(x, protos)
return dis, plabels
@@ -37,17 +41,28 @@ class Model(torch.nn.Module):
# Build the GLVQ model
model = Model()
# Print summary using torchinfo (might be buggy/incorrect)
print(summary(model))
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
fig = plt.figure('Prototype Visualization')
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(70):
# Compute loss.
distances, plabels = model(torch.tensor(x_train))
loss = criterion([distances, plabels], torch.tensor(y_train))
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}')
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
with torch.no_grad():
pred = wtac(dis, plabels)
correct = pred.eq(y_in.view_as(pred)).sum().item()
acc = 100. * correct / len(x_train)
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
# Take a gradient descent step
optimizer.zero_grad()
@@ -55,44 +70,39 @@ for epoch in range(70):
optimizer.step()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)):
print("Stopping training because of `nan` in prototypes.")
break
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
cmap = 'viridis'
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
ax.set_title(title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0],
protos[:, 1],
c=plabels,
cmap=cmap,
edgecolor='k',
marker='D',
edgecolor="k",
marker="D",
s=50)
# Paint decision regions
border = 1
resolution = 50
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min(), x[:, 0].max()
y_min, y_max = x[:, 1].min(), x[:, 1].max()
x_min, x_max = x_min - border, x_max + border
y_min, y_max = y_min - border, y_max + border
try:
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
np.arange(y_min, y_max, 1.0 / resolution))
except ValueError as ve:
print(ve)
raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
f'x_min - x_max is {x_max - x_min}.')
except MemoryError as me:
print(me)
raise ValueError('Too many points. ' 'Try reducing the resolution.')
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
torch_input = torch.from_numpy(mesh_input)
torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0]
y_pred = np.argmin(d.detach().numpy(), axis=1)
w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(plabels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions
@@ -100,4 +110,5 @@ for epoch in range(70):
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)

102
examples/gmlvq_tecator.py Normal file
View File

@@ -0,0 +1,102 @@
"""ProtoTorch "siamese" GMLVQ example using Tecator."""
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles
# Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
class Model(torch.nn.Module):
def __init__(self, **kwargs):
"""GMLVQ model as a siamese network."""
super().__init__()
x, y = train_data.data, train_data.targets
self.p1 = Prototypes1D(input_dim=100,
prototypes_per_class=2,
nclasses=2,
prototype_initializer="stratified_random",
data=[x, y])
self.omega = torch.nn.Linear(in_features=100,
out_features=100,
bias=False)
torch.nn.init.eye_(self.omega.weight)
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
# Process `x` and `protos` through `omega`
x_map = self.omega(x)
protos_map = self.omega(protos)
# Compute distances and output
dis = sed(x_map, protos_map)
return dis, plabels
# Build the GLVQ model
model = Model()
# Print a summary of the model
print(model)
# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing="identity", beta=10)
# Training loop
for epoch in range(150):
epoch_loss = 0.0 # zero-out epoch loss
optimizer.zero_grad() # zero-out gradients
for xb, yb in train_loader:
# Compute loss
distances, plabels = model(xb)
loss = criterion([distances, plabels], yb)
epoch_loss += loss.item()
# Backprop
loss.backward()
# Take a gradient descent step
optimizer.step()
scheduler.step()
lr = optimizer.param_groups[0]["lr"]
print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")
# Get the omega matrix form the model
omega = model.omega.weight.data.numpy().T
# Visualize the lambda matrix
title = "Lambda Matrix Visualization"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
plabels = model.p1.prototype_labels
# Visualize the prototypes
title = "Tecator Prototypes"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel("Spectral frequencies")
ax.set_ylabel("Absorption")
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
for x, y in zip(protos, plabels):
ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()

162
examples/gtlvq_mnist.py Normal file
View File

@@ -0,0 +1,162 @@
"""
ProtoTorch GTLVQ example using MNIST data.
The GTLVQ is placed as an classification model on
top of a CNN, considered as featurer extractor.
Initialization of subpsace and prototypes in
Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from prototorch.modules.losses import GLVQLoss
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.models import GTLVQ
# Parameters and options
n_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:1"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
# Configures reproducability
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
'./files/',
train=True,
download=True,
transform=torchvision.transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])),
batch_size=batch_size_train,
shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
'./files/',
train=False,
download=True,
transform=torchvision.transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])),
batch_size=batch_size_test,
shuffle=True)
# Define the GLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
def __init__(
self,
num_classes,
subspace_data,
prototype_data,
tangent_projection_type="local",
prototypes_per_class=2,
bottleneck_dim=128,
):
super(CNNGTLVQ, self).__init__()
#Feature Extractor - Simple CNN
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
nn.MaxPool2d(2), nn.Dropout(0.25),
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
nn.Dropout(0.5), nn.LeakyReLU(),
nn.LayerNorm(bottleneck_dim))
# Forward pass of subspace and prototype initialization data through feature extractor
subspace_data = self.fe(subspace_data)
prototype_data[0] = self.fe(prototype_data[0])
# Initialization of GTLVQ
self.gtlvq = GTLVQ(num_classes,
subspace_data,
prototype_data,
tangent_projection_type=tangent_projection_type,
feature_dim=bottleneck_dim,
prototypes_per_class=prototypes_per_class)
def forward(self, x):
# Feature Extraction
x = self.fe(x)
# GTLVQ Forward pass
dis = self.gtlvq(x)
return dis
# Get init data
subspace_data = torch.cat(
[next(iter(train_loader))[0],
next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))
# Build the CNN GTLVQ model
model = CNNGTLVQ(10,
subspace_data,
prototype_data,
tangent_projection_type="local",
bottleneck_dim=128).to(device)
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.Adam([{
'params': model.fe.parameters()
}, {
'params': model.gtlvq.parameters()
}],
lr=learning_rate)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
# Training loop
for epoch in range(n_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train()
x_train, y_train = x_train.to(device), y_train.to(device)
optimizer.zero_grad()
distances = model(x_train)
plabels = model.gtlvq.cls.prototype_labels.to(device)
# Compute loss.
loss = criterion([distances, plabels], y_train)
loss.backward()
optimizer.step()
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
model.gtlvq.orthogonalize_subspace()
if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels)
print(
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}')
# Test
with torch.no_grad():
model.eval()
correct = 0
total = 0
for x_test, y_test in test_loader:
x_test, y_test = x_test.to(device), y_test.to(device)
test_distances = model(torch.tensor(x_test))
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
i = torch.argmin(test_distances, 1)
correct += torch.sum(y_test == test_plabels[i])
total += y_test.size(0)
print('Accuracy of the network on the test images: %d %%' %
(torch.true_divide(correct, total) * 100))
# Save the model
PATH = './glvq_mnist_model.pth'
torch.save(model.state_dict(), PATH)

106
examples/lgmlvq_iris.py Normal file
View File

@@ -0,0 +1,106 @@
"""ProtoTorch LGMLVQ example using 2D Iris data."""
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.functions.init import eye_
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
# Prepare training data
x_train, y_train = load_iris(True)
x_train = x_train[:, [0, 2]]
# Define the model
class Model(torch.nn.Module):
def __init__(self):
"""Local-GMLVQ model."""
super().__init__()
self.p1 = Prototypes1D(input_dim=2,
prototype_distribution=[1, 2, 2],
prototype_initializer="stratified_random",
data=[x_train, y_train])
omegas = torch.zeros(5, 2, 2)
self.omegas = torch.nn.Parameter(omegas)
eye_(self.omegas)
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
omegas = self.omegas
dis = lomega_distance(x, protos, omegas)
return dis, plabels
# Build the model
model = Model()
# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(100):
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
acc = accuracy_score(y_train, y_pred)
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
log_string += f"Acc: {acc * 100:05.02f}%"
print(log_string)
# Take a gradient descent step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
ax.set_title(title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
ax.scatter(protos[:, 0],
protos[:, 1],
c=plabels,
cmap=cmap,
edgecolor='k',
marker='D',
s=50)
# Paint decision regions
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
d, plabels = model(torch.Tensor(mesh_input))
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)

View File

@@ -1 +1,46 @@
__version__ = '0.1.0-dev0'
"""ProtoTorch package."""
# #############################################
# Core Setup
# #############################################
__version__ = "0.3.0-dev0"
from prototorch import datasets, functions, modules
__all_core__ = [
"datasets",
"functions",
"modules",
]
# #############################################
# Plugin Loader
# #############################################
import pkgutil
import pkg_resources
__path__ = pkgutil.extend_path(__path__, __name__)
def discover_plugins():
return {
entry_point.name: entry_point.load()
for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
}
discovered_plugins = discover_plugins()
locals().update(discovered_plugins)
# Generate combines __version__ and __all__
version_plugins = "\n".join(
[
"- " + name + ": v" + plugin.__version__
for name, plugin in discovered_plugins.items()
]
)
if version_plugins != "":
version_plugins = "\nPlugins: \n" + version_plugins
version = "core: v" + __version__ + version_plugins
__all__ = __all_core__ + list(discovered_plugins.keys())

View File

@@ -0,0 +1,7 @@
"""ProtoTorch datasets."""
from .tecator import Tecator
__all__ = [
'Tecator',
]

View File

@@ -0,0 +1,90 @@
"""ProtoTorch abstract dataset classes.
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
"""
import os
import torch
class Dataset(torch.utils.data.Dataset):
"""Abstract dataset class to be inherited."""
_repr_indent = 2
def __init__(self, root):
if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root)
self.root = root
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class ProtoDataset(Dataset):
"""Abstract dataset class to be inherited."""
training_file = "training.pt"
test_file = "test.pt"
def __init__(self, root, train=True, download=True, verbose=True):
super().__init__(root)
self.train = train # training set or test set
self.verbose = verbose
if download:
self._download()
if not self._check_exists():
raise RuntimeError(
"Dataset not found. " "You can use download=True to download it"
)
data_file = self.training_file if self.train else self.test_file
self.data, self.targets = torch.load(
os.path.join(self.processed_folder, data_file)
)
@property
def raw_folder(self):
return os.path.join(self.root, self.__class__.__name__, "raw")
@property
def processed_folder(self):
return os.path.join(self.root, self.__class__.__name__, "processed")
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
return os.path.exists(
os.path.join(self.processed_folder, self.training_file)
) and os.path.exists(os.path.join(self.processed_folder, self.test_file))
def __repr__(self):
head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
lines = [head] + [" " * self._repr_indent + line for line in body]
return "\n".join(lines)
def extra_repr(self):
return f"Split: {'Train' if self.train is True else 'Test'}"
def __len__(self):
return len(self.data)
def _download(self):
raise NotImplementedError

View File

@@ -0,0 +1,103 @@
"""Tecator dataset for classification.
URL:
http://lib.stat.cmu.edu/datasets/tecator
LICENCE / TERMS / COPYRIGHT:
This is the Tecator data set: The task is to predict the fat content
of a meat sample on the basis of its near infrared absorbance spectrum.
-------------------------------------------------------------------------
1. Statement of permission from Tecator (the original data source)
These data are recorded on a Tecator Infratec Food and Feed Analyzer
working in the wavelength range 850 - 1050 nm by the Near Infrared
Transmission (NIT) principle. Each sample contains finely chopped pure
meat with different moisture, fat and protein contents.
If results from these data are used in a publication we want you to
mention the instrument and company name (Tecator) in the publication.
In addition, please send a preprint of your article to
Karin Thente, Tecator AB,
Box 70, S-263 21 Hoganas, Sweden
The data are available in the public domain with no responsability from
the original data source. The data can be redistributed as long as this
permission note is attached.
For more information about the instrument - call Perstorp Analytical's
representative in your area.
Description:
For each meat sample the data consists of a 100 channel spectrum of
absorbances and the contents of moisture (water), fat and protein.
The absorbance is -log10 of the transmittance
measured by the spectrometer. The three contents, measured in percent,
are determined by analytic chemistry.
"""
import os
import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset):
"""
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
for classification.
"""
_resources = [
("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
] # (google_storage_id, md5hash)
classes = ["0 - low_fat", "1 - high_fat"]
def __getitem__(self, index):
img, target = self.data[index], int(self.targets[index])
return img, target
def _download(self):
"""Download the data if it doesn't exist in already."""
if self._check_exists():
return
if self.verbose:
print("Making directories...")
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
if self.verbose:
print("Downloading...")
for fileid, md5 in self._resources:
filename = "tecator.npz"
download_file_from_google_drive(
fileid, root=self.raw_folder, filename=filename, md5=md5
)
if self.verbose:
print("Processing...")
with np.load(
os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
) as f:
x_train, y_train = f["x_train"], f["y_train"]
x_test, y_test = f["x_test"], f["y_test"]
training_set = [
torch.tensor(x_train, dtype=torch.float32),
torch.tensor(y_train),
]
test_set = [
torch.tensor(x_test, dtype=torch.float32),
torch.tensor(y_test),
]
with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
torch.save(test_set, f)
if self.verbose:
print("Done!")

View File

@@ -0,0 +1,12 @@
"""ProtoTorch functions."""
from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac
__all__ = [
'identity',
'sigmoid_beta',
'swish_beta',
'knnc',
'wtac',
]

View File

@@ -5,44 +5,60 @@ import torch
ACTIVATIONS = dict()
def register_activation(func):
ACTIVATIONS[func.__name__] = func
return func
# def register_activation(scriptf):
# ACTIVATIONS[scriptf.name] = scriptf
# return scriptf
def register_activation(function):
"""Add the activation function to the registry."""
ACTIVATIONS[function.__name__] = function
return function
@register_activation
def identity(input, **kwargs):
""":math:`f(x) = x`"""
return input
# @torch.jit.script
def identity(x, beta=torch.tensor(0)):
"""Identity activation function.
Definition:
:math:`f(x) = x`
"""
return x
@register_activation
def sigmoid_beta(input, beta=10):
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
# @torch.jit.script
def sigmoid_beta(x, beta=torch.tensor(10)):
r"""Sigmoid activation function with scaling.
Definition:
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
Keyword Arguments:
beta (float): Parameter :math:`\\beta`
beta (`torch.tensor`): Scaling parameter :math:`\beta`
"""
out = torch.reciprocal(1.0 + torch.exp(-beta * input))
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
return out
@register_activation
def swish_beta(input, beta=10):
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
# @torch.jit.script
def swish_beta(x, beta=torch.tensor(10)):
r"""Swish activation function with scaling.
Definition:
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
Keyword Arguments:
beta (float): Parameter :math:`\\beta`
beta (`torch.tensor`): Scaling parameter :math:`\beta`
"""
out = input * sigmoid_beta(input, beta=beta)
out = x * sigmoid_beta(x, beta=beta)
return out
def get_activation(funcname):
"""Deserialize the activation function."""
if callable(funcname):
return funcname
else:
if funcname in ACTIVATIONS:
return ACTIVATIONS.get(funcname)
else:
raise NameError(f'Activation {funcname} was not found.')
if funcname in ACTIVATIONS:
return ACTIVATIONS.get(funcname)
raise NameError(f'Activation {funcname} was not found.')

View File

@@ -3,13 +3,43 @@
import torch
# @torch.jit.script
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
if distances.size()[1] == nclasses:
# skip if only one prototype per class
return distances
batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float('inf'))
# distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels):
# cdists = distances.T[labels == cl]
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2:
# if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze()
if labels.ndim == 2:
# Transpose to return with `batch_size` first and
# reverse the columns to fix the ordering of the classes
return torch.flip(winning_distances.T, dims=(1, ))
return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels
# @torch.jit.script
def knnc(distances, labels, k):
winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels

View File

@@ -1,13 +1,24 @@
"""ProtoTorch distance functions."""
import torch
from prototorch.functions.helper import (
equal_int_shape,
_int_and_mixed_shape,
_check_shapes,
)
import numpy as np
def squared_euclidean_distance(x, y):
"""Compute the squared Euclidean distance between :math:`x` and :math:`y`.
r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
Expected dimension of x is 2.
Expected dimension of y is 2.
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
**Alias:**
``prototorch.functions.distances.sed``
"""
expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x
@@ -17,10 +28,15 @@ def squared_euclidean_distance(x, y):
def euclidean_distance(x, y):
"""Compute the Euclidean distance between :math:`x` and :math:`y`.
r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
Expected dimension of x is 2.
Expected dimension of y is 2.
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
:returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor`
"""
distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised)
@@ -28,30 +44,30 @@ def euclidean_distance(x, y):
def lpnorm_distance(x, y, p):
"""Compute :math:`{\\langle x, y \\rangle}_p`.
r"""
Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
Also known as Minkowski distance.
Expected dimension of x is 2.
Expected dimension of y is 2.
Compute :math:`{\| \bm x - \bm y \|}_p`.
Calls ``torch.cdist``
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param p: p parameter of the lp norm
"""
# # DEPRECATED in favor of torch.cdist
# expanded_x = x.unsqueeze(dim=1)
# batchwise_difference = y - expanded_x
# differences_raised = torch.pow(batchwise_difference, p)
# distances_raised = torch.sum(differences_raised, axis=2)
# distances = torch.pow(distances_raised, 1.0 / p)
# return distances
distances = torch.cdist(x, y, p=p)
return distances
def omega_distance(x, y, omega):
"""Omega distance.
r"""Omega distance.
Compute :math:`{\\langle \\Omega x, \\Omega y \\rangle}_p`
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
Expected dimension of x is 2.
Expected dimension of y is 2.
Expected dimension of omega is 2.
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omega: Two dimensional matrix
"""
projected_x = x @ omega
projected_y = y @ omega
@@ -60,19 +76,193 @@ def omega_distance(x, y, omega):
def lomega_distance(x, y, omegas):
"""Localized Omega distance.
r"""Localized Omega distance.
Compute :math:`{\\langle \\Omega_k x, \\Omega_k y_k \\rangle}_p`
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
Expected dimension of x is 2.
Expected dimension of y is 2.
Expected dimension of omegas is 3.
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omegas: Three dimensional matrix
"""
projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
differences_squared = batchwise_difference ** 2
distances = torch.sum(differences_squared, dim=2)
distances = distances.permute(1, 0)
return distances
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
r"""Computes an euclidean distances matrix given two distinct vectors.
last dimension must be the vector dimension!
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
- ``x.shape = (number_of_x_vectors, vector_dim)``
- ``y.shape = (number_of_y_vectors, vector_dim)``
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
"""
for tensor in [x, y]:
if tensor.ndim != 2:
raise ValueError(
"The tensor dimension must be two. You provide: tensor.ndim="
+ str(tensor.ndim)
+ "."
)
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
raise ValueError(
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
+ str(tuple(x.shape)[1])
+ " and tuple(y.shape)(y)[1]="
+ str(tuple(y.shape)[1])
+ "."
)
y = torch.transpose(y)
diss = (
torch.sum(x ** 2, axis=1, keepdims=True)
- 2 * torch.dot(x, y)
+ torch.sum(y ** 2, axis=0, keepdims=True)
)
if not squared:
if epsilon == 0:
diss = torch.sqrt(diss)
else:
diss = torch.sqrt(torch.max(diss, epsilon))
return diss
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
For more info about Tangen distances see
DOI:10.1109/IJCNN.2016.7727534.
The subspaces is always assumed as transposed and must be orthogonal!
For local non sparse signals subspaces must be provided!
- shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
- shape(protos): proto_number x dim1 x dim2 x ... x dimN
- shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
subspace should be orthogonalized
Pytorch implementation of Sascha Saralajew's tensorflow code.
Translation by Christoph Raab
"""
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
subspace_int_shape = tuple(subspaces.shape)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# for sparse signals, we use the memory efficient implementation
if signal_int_shape[1] == 1:
signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
if len(atom_axes) > 1:
protos = torch.reshape(protos, [proto_shape[0], -1])
if subspaces.ndim == 2:
# clean solution without map if the matrix_scope is global
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
subspaces, torch.transpose(subspaces)
)
projected_signals = torch.dot(signals, projectors)
projected_protos = torch.dot(protos, projectors)
diss = euclidean_distance_matrix(
projected_signals, projected_protos, squared=squared, epsilon=epsilon
)
diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
)
return torch.permute(diss, [0, 2, 1])
else:
# no solution without map possible --> memory efficient but slow!
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
subspaces, subspaces
) # K.batch_dot(subspaces, subspaces, [2, 2])
projected_protos = (
protos @ subspaces
).T # K.batch_dot(projectors, protos, [1, 1]))
def projected_norm(projector):
return torch.sum(torch.dot(signals, projector) ** 2, axis=1)
diss = (
torch.transpose(map(projected_norm, projectors))
- 2 * torch.dot(signals, projected_protos)
+ torch.sum(projected_protos ** 2, axis=0, keepdims=True)
)
if not squared:
if epsilon == 0:
diss = torch.sqrt(diss)
else:
diss = torch.sqrt(torch.max(diss, epsilon))
diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
)
return torch.permute(diss, [0, 2, 1])
else:
signals = signals.permute([0, 2, 1] + atom_axes)
diff = signals - protos
# global tangent space
if subspaces.ndim == 2:
# Scope Projectors
projectors = subspaces #
# Scope: Tangentspace Projections
diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
)
projected_diff = diff @ projectors
projected_diff = torch.reshape(
projected_diff,
(signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
)
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([0, 2, 1])
# local tangent spaces
else:
# Scope: Calculate Projectors
projectors = subspaces
# Scope: Tangentspace Projections
diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
)
diff = diff.permute([1, 0, 2])
projected_diff = torch.bmm(diff, projectors)
projected_diff = torch.reshape(
projected_diff,
(signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
)
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1)
# Aliases
sed = squared_euclidean_distance

View File

@@ -0,0 +1,89 @@
import torch
def calculate_prototype_accuracy(y_pred, y_true, plabels):
"""Computes the accuracy of a prototype based model.
via Winner-Takes-All rule.
Requirement:
y_pred.shape == y_true.shape
unique(y_pred) in plabels
"""
with torch.no_grad():
idx = torch.argmin(y_pred, axis=1)
return torch.true_divide(torch.sum(y_true == plabels[idx]),
len(y_pred)) * 100
def predict_label(y_pred, plabels):
r""" Predicts labels given a prediction of a prototype based model.
"""
with torch.no_grad():
return plabels[torch.argmin(y_pred, 1)]
def mixed_shape(inputs):
if not torch.is_tensor(inputs):
raise ValueError('Input must be a tensor.')
else:
int_shape = list(inputs.shape)
# sometimes int_shape returns mixed integer types
int_shape = [int(i) if i is not None else i for i in int_shape]
tensor_shape = inputs.shape
for i, s in enumerate(int_shape):
if s is None:
int_shape[i] = tensor_shape[i]
return tuple(int_shape)
def equal_int_shape(shape_1, shape_2):
if not isinstance(shape_1,
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
raise ValueError('Input shapes must list or tuple.')
for shape in [shape_1, shape_2]:
if not all([isinstance(x, int) or x is None for x in shape]):
raise ValueError(
'Input shapes must be list or tuple of int and None values.')
if len(shape_1) != len(shape_2):
return False
else:
for axis, value in enumerate(shape_1):
if value is not None and shape_2[axis] not in {value, None}:
return False
return True
def _check_shapes(signal_int_shape, proto_int_shape):
if len(signal_int_shape) < 4:
raise ValueError(
"The number of signal dimensions must be >=4. You provide: " +
str(len(signal_int_shape)))
if len(proto_int_shape) < 2:
raise ValueError(
"The number of proto dimensions must be >=2. You provide: " +
str(len(proto_int_shape)))
if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
raise ValueError(
"The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
+ str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
str(proto_int_shape[1:]))
# not a sparse signal
if signal_int_shape[1] != 1:
if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
raise ValueError(
"If the signal is not sparse, the number of prototypes must be equal in signals and "
"protos. You provide: " + str(signal_int_shape[1]) + " != " +
str(proto_int_shape[0]))
return True
def _int_and_mixed_shape(tensor):
shape = mixed_shape(tensor)
int_shape = tuple([i if isinstance(i, int) else None for i in shape])
return shape, int_shape

View File

@@ -7,87 +7,101 @@ import torch
INITIALIZERS = dict()
def register_initializer(func):
INITIALIZERS[func.__name__] = func
return func
def register_initializer(function):
"""Add the initializer to the registry."""
INITIALIZERS[function.__name__] = function
return function
def labels_from(distribution):
def labels_from(distribution, one_hot=True):
"""Takes a distribution tensor and returns a labels tensor."""
nclasses = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
# labels = [l for cl in llist for l in cl] # flatten the list of lists
labels = list(chain(*llist)) # flatten using itertools.chain
return torch.tensor(labels, requires_grad=False)
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
plabels = torch.tensor(flat_llist, requires_grad=False)
if one_hot:
return torch.eye(nclasses)[plabels]
return plabels
@register_initializer
def ones(x_train, y_train, prototype_distribution):
def ones(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.ones(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution)
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def zeros(x_train, y_train, prototype_distribution):
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.zeros(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution)
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def rand(x_train, y_train, prototype_distribution):
def rand(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.rand(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution)
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def randn(x_train, y_train, prototype_distribution):
def randn(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.randn(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution)
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def stratified_mean(x_train, y_train, prototype_distribution):
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution)
for i, l in enumerate(plabels):
xl = x_train[y_train == l]
plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot:
nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher]
mean_xl = torch.mean(xl, dim=0)
protos[i] = mean_xl
plabels = labels_from(prototype_distribution, one_hot=one_hot)
return protos, plabels
@register_initializer
def stratified_random(x_train, y_train, prototype_distribution):
gen = torch.manual_seed(torch.initial_seed())
def stratified_random(x_train,
y_train,
prototype_distribution,
one_hot=True,
epsilon=1e-7):
nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution)
for i, l in enumerate(plabels):
xl = x_train[y_train == l]
rand_index = torch.zeros(1).long().random_(0,
xl.shape[1] - 1,
generator=gen)
plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot:
nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index]
protos[i] = random_xl
protos[i] = random_xl + epsilon
plabels = labels_from(prototype_distribution, one_hot=one_hot)
return protos, plabels
def get_initializer(funcname):
"""Deserialize the initializer."""
if callable(funcname):
return funcname
else:
if funcname in INITIALIZERS:
return INITIALIZERS.get(funcname)
else:
raise NameError(f'Initializer {funcname} was not found.')
if funcname in INITIALIZERS:
return INITIALIZERS.get(funcname)
raise NameError(f'Initializer {funcname} was not found.')

View File

@@ -3,23 +3,24 @@
import torch
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
if prototype_labels.ndim == 2:
def _get_dp_dm(distances, targets, plabels):
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
if plabels.ndim == 2:
# if the labels are one-hot vectors
nclasses = target_labels.size()[1]
nclasses = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
not_matcher = torch.bitwise_not(matcher)
dplus_criterion = distances * matcher > 0.0
dminus_criterion = distances * not_matcher > 0.0
inf = torch.full_like(distances, fill_value=float('inf'))
distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=1, keepdim=True).values
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
return dp, dm
mu = (dpluses - dminuses) / (dpluses + dminuses)
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm)
return mu

View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import torch
def orthogonalization(tensors):
r""" Orthogonalization of a given tensor via polar decomposition.
"""
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def trace_normalization(tensors):
r""" Trace normalization
"""
epsilon = torch.tensor([1e-10], dtype=torch.float64)
# Scope trace_normalization
constant = torch.trace(tensors)
if epsilon != 0:
constant = torch.max(constant, epsilon)
return tensors / constant

View File

@@ -0,0 +1,7 @@
"""ProtoTorch modules."""
from .prototypes import Prototypes1D
__all__ = [
'Prototypes1D',
]

View File

@@ -7,15 +7,14 @@ from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module):
"""GLVQ Loss."""
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = beta
self.beta = torch.tensor(beta)
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, plabels)
mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)

View File

@@ -0,0 +1,190 @@
from torch import nn
import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
Parameters
----------
num_classes: int
Number of classes of the given classification problem.
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
Subspace data for the point approximation, required
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ.
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
local_proj
global
local: computes the tangent distances without emphasizing projected
data. Only distances are available
local_proj: computs tangent distances and returns the projected data
for further use. Be careful: data is repeated by number of prototypes
global: Number of subspaces is set to one and every prototypes
uses the same.
prototypes_per_class: int (default=2,optional)
Number of prototypes per class
feature_dim: int (default=256)
Dimensionality of the feature space specified as integer.
Prototype dimension.
Notes
-----
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ requires subspace projectors for transforming the data
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
approximation of the assumed manifold.
In practice prototypes and data are projected on this manifold
and pairwise euclidean distance computes.
References
----------
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
in classification based on manifolc. models and its relation
to tangent metric learning. In: 2017 International Joint
Conference on Neural Networks (IJCNN).
Bd. 2017-May : IEEE, 2017, S. 17561765
"""
def __init__(
self,
num_classes,
subspace_data=None,
prototype_data=None,
subspace_size=256,
tangent_projection_type='local',
prototypes_per_class=2,
feature_dim=256,
):
super(GTLVQ, self).__init__()
self.num_protos = num_classes * prototypes_per_class
self.subspace_size = feature_dim if subspace_size is None else subspace_size
self.feature_dim = feature_dim
if subspace_data is None:
raise ValueError('Init Data must be specified!')
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer='stratified_mean',
data=prototype_data)
def forward(self, x):
# Tangent Projection
if self.tpt == 'local_proj':
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
elif self.tpt == "gloabl":
dis = self.global_tangent_distances(x)
else:
dis = (x @ self.cls.prototypes.T) / (
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
self.cls.prototypes, dim=1, keepdim=True).T)
return dis
def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x):
# Tangent Projection
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
# Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_projection(self,
signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
# Origin Source Code
# Origin Author:
protos = self.cls.prototypes
subspaces = self.subspaces
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
projected_data = torch.bmm(data, subspaces)
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
diff = projected_data - projected_protos
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self):
return {
"params": self.cls.prototypes,
}, {
"params": self.subspaces
}
def orthogonalize_subspace(self):
if self.subspaces is not None:
with torch.no_grad():
ortho_subpsaces = orthogonalization(
self.subspaces
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
self.subspaces)
self.subspaces.copy_(ortho_subpsaces)

View File

@@ -1,57 +1,132 @@
"""ProtoTorch prototype modules."""
import warnings
import torch
from prototorch.functions.initializers import get_initializer
class AddPrototypes1D(torch.nn.Module):
class _Prototypes(torch.nn.Module):
"""Abstract prototypes class."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution:
warnings.warn("Are you sure about the `0` in "
"`prototype_distribution`?")
def extra_repr(self):
return f"prototypes.shape: {tuple(self.prototypes.shape)}"
def forward(self):
return self.prototypes, self.prototype_labels
class Prototypes1D(_Prototypes):
"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string.
"""
def __init__(self,
prototypes_per_class=1,
prototype_initializer="ones",
prototype_distribution=None,
prototype_initializer='ones',
data=None,
dtype=torch.float32,
one_hot_labels=False,
**kwargs):
if data is None:
if 'input_dim' not in kwargs:
raise NameError('`input_dim` required if '
'no `data` is provided.')
if prototype_distribution is not None:
nclasses = sum(prototype_distribution)
else:
if 'nclasses' not in kwargs:
raise NameError('`prototype_distribution` required if '
'both `data` and `nclasses` are not '
'provided.')
nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim')
# input_shape = (input_dim, )
x_train = torch.rand(nclasses, input_dim)
y_train = torch.arange(nclasses)
# Convert tensors to python lists before processing
if prototype_distribution is not None:
if not isinstance(prototype_distribution, list):
prototype_distribution = prototype_distribution.tolist()
else:
x_train, y_train = data
x_train = torch.as_tensor(x_train)
y_train = torch.as_tensor(y_train)
if data is None:
if "input_dim" not in kwargs:
raise NameError("`input_dim` required if "
"no `data` is provided.")
if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution)
else:
if "nclasses" not in kwargs:
raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not "
"provided.")
kwargs_nclasses = kwargs.pop("nclasses")
input_dim = kwargs.pop("input_dim")
if prototype_initializer in [
"stratified_mean", "stratified_random"
]:
warnings.warn(
f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. "
"Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
data = [x_train, y_train]
x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2:
raise ValueError("`data[0].ndim != 2`.")
if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError("`one_hot_labels` is set to `False` "
"but target labels in `data` "
"are one-hot-encoded.")
if y_train.ndim == 1 and one_hot_labels:
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
# Verify input dimension if `input_dim` is provided
if "input_dim" in kwargs:
input_dim = kwargs.pop("input_dim")
if input_dim != x_train.shape[1]:
raise ValueError(f"Provided `input_dim`={input_dim} does "
"not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided
if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses:
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={nclasses}")
super().__init__(**kwargs)
self.prototypes_per_class = prototypes_per_class
if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses
with torch.no_grad():
if not prototype_distribution:
num_classes = torch.unique(y_train).shape[0]
self.prototype_distribution = torch.tensor(
[self.prototypes_per_class] * num_classes)
else:
self.prototype_distribution = torch.tensor(
prototype_distribution)
self.prototype_distribution = torch.tensor(prototype_distribution)
self._validate_prototype_distribution()
self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer(
x_train,
y_train,
prototype_distribution=self.prototype_distribution)
self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels
prototype_distribution=self.prototype_distribution,
one_hot=one_hot_labels,
)
def forward(self):
return self.prototypes, self.prototype_labels
# Register module parameters
self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = torch.nn.Parameter(
prototype_labels.type(dtype)).requires_grad_(False)

View File

@@ -0,0 +1 @@
from .colors import color_scheme, get_legend_handles

View File

@@ -0,0 +1,74 @@
"""ProtoFlow color utilities."""
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.colors import to_hex
from matplotlib.colors import to_rgb
import matplotlib.lines as mlines
def color_scheme(n, cmap="viridis", form="hex", tikz=False,
zero_indexed=False):
"""Return *n* colors from the color scheme.
Arguments:
n (int): number of colors to return
Keyword Arguments:
cmap (str): Name of a matplotlib `colormap\
<https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html>`_.
form (str): Colorformat (supports "hex" and "rgb").
tikz (bool): Output as `TikZ <https://github.com/pgf-tikz/pgf>`_
command.
zero_indexed (bool): Use zero indexing for output array.
Returns:
(list): List of colors
"""
cmap = cm.get_cmap(cmap)
colornorm = Normalize(vmin=1, vmax=n)
hex_map = dict()
rgb_map = dict()
for cl in range(1, n + 1):
if zero_indexed:
hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
else:
hex_map[cl] = to_hex(cmap(colornorm(cl)))
rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
if tikz:
for k, v in rgb_map.items():
print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
if form == "hex":
return hex_map
elif form == "rgb":
return rgb_map
else:
return hex_map
def get_legend_handles(labels, marker="dots", zero_indexed=False):
"""Return matplotlib legend handles and colors."""
handles = list()
n = len(labels)
colors = color_scheme(n,
cmap="viridis",
form="hex",
zero_indexed=zero_indexed)
for label, color in zip(labels, colors.values()):
if marker == "dots":
handle = mlines.Line2D([], [],
color="white",
markerfacecolor=color,
marker="o",
markersize=10,
markeredgecolor="k",
label=label)
else:
handle = mlines.Line2D([], [],
color=color,
marker="",
markersize=15,
label=label)
handles.append(handle)
return handles, colors

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
matplotlib==3.1.2
pytest==5.3.4
requests==2.22.0
codecov==2.0.22
tqdm==4.44.1

117
setup.py
View File

@@ -1,49 +1,82 @@
"""Install ProtoTorch."""
"""
_____ _ _______ _
| __ \ | | |__ __| | |
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|
ProtoTorch Core Package
"""
from setuptools import setup
from setuptools import find_packages
PROJECT_URL = 'https://github.com/si-cim/prototorch'
DOWNLOAD_URL = 'https://github.com/si-cim/prototorch.git'
PROJECT_URL = "https://github.com/si-cim/prototorch"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
with open('README.md', 'r') as fh:
with open("README.md", "r") as fh:
long_description = fh.read()
setup(name='prototorch',
version='0.1.0-dev0',
description='Highly extensible, GPU-supported '
'Learning Vector Quantization (LVQ) toolbox '
'built using PyTorch and its nn API.',
long_description=long_description,
long_description_content_type='text/markdown',
author='Jensun Ravichandran',
author_email='jjensun@gmail.com',
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license='MIT',
install_requires=[
'torch>=1.3.1',
'torchvision>=0.5.0',
'numpy>=1.9.1',
],
extras_require={
'examples': [
'sklearn',
'matplotlib',
],
'tests': ['pytest'],
},
classifiers=[
'Development Status :: 2 - Pre-Alpha', 'Environment :: Console',
'Intended Audience :: Developers', 'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Operating System :: OS Independent',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules'
],
packages=find_packages())
INSTALL_REQUIRES = [
"torch>=1.3.1",
"torchvision>=0.5.0",
"numpy>=1.9.1",
]
DOCS = [
"recommonmark",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
]
DATASETS = [
"requests",
"tqdm",
]
EXAMPLES = [
"sklearn",
"matplotlib",
"torchinfo",
]
TESTS = ["pytest"]
ALL = DOCS + DATASETS + EXAMPLES + TESTS
setup(
name="prototorch",
version="0.3.0-dev0",
description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Jensun Ravichandran",
author_email="jjensun@gmail.com",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
install_requires=INSTALL_REQUIRES,
extras_require={
"docs": DOCS,
"datasets": DATASETS,
"examples": EXAMPLES,
"tests": TESTS,
"all": ALL,
},
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Console",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
packages=find_packages(),
zip_safe=False,
)

0
tests/__init__.py Normal file
View File

95
tests/test_datasets.py Normal file
View File

@@ -0,0 +1,95 @@
"""ProtoTorch datasets test suite."""
import os
import shutil
import unittest
import torch
from prototorch.datasets import abstract, tecator
class TestAbstract(unittest.TestCase):
def test_getitem(self):
with self.assertRaises(NotImplementedError):
abstract.Dataset('./artifacts')[0]
def test_len(self):
with self.assertRaises(NotImplementedError):
len(abstract.Dataset('./artifacts'))
class TestProtoDataset(unittest.TestCase):
def test_getitem(self):
with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts')[0]
def test_download(self):
with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts').download()
class TestTecator(unittest.TestCase):
def setUp(self):
self.artifacts_dir = './artifacts/Tecator'
self._remove_artifacts()
def _remove_artifacts(self):
if os.path.exists(self.artifacts_dir):
shutil.rmtree(self.artifacts_dir)
def test_download_false(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
self._remove_artifacts()
with self.assertRaises(RuntimeError):
_ = tecator.Tecator(rootdir, download=False)
def test_download_caching(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
_ = tecator.Tecator(rootdir, download=True, verbose=False)
_ = tecator.Tecator(rootdir, download=False, verbose=False)
def test_repr(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
train = tecator.Tecator(rootdir, download=True, verbose=True)
self.assertTrue('Split: Train' in train.__repr__())
def test_download_train(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
train = tecator.Tecator(root=rootdir,
train=True,
download=True,
verbose=False)
train = tecator.Tecator(root=rootdir, download=True, verbose=False)
x_train, y_train = train.data, train.targets
self.assertEqual(x_train.shape[0], 144)
self.assertEqual(y_train.shape[0], 144)
self.assertEqual(x_train.shape[1], 100)
def test_download_test(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x_test, y_test = test.data, test.targets
self.assertEqual(x_test.shape[0], 71)
self.assertEqual(y_test.shape[0], 71)
self.assertEqual(x_test.shape[1], 100)
def test_class_to_idx(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = test.class_to_idx
def test_getitem(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x, y = test[0]
self.assertEqual(x.shape[0], 100)
self.assertIsInstance(y, int)
def test_loadable_with_dataloader(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
def tearDown(self):
pass

View File

@@ -6,7 +6,148 @@ import numpy as np
import torch
from prototorch.functions import (activations, competitions, distances,
initializers)
initializers, losses)
class TestActivations(unittest.TestCase):
def setUp(self):
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
self.x = torch.randn(1024, 1)
def test_registry(self):
self.assertIsNotNone(activations.ACTIVATIONS)
def test_funcname_deserialization(self):
for funcname in self.flist:
f = activations.get_activation(funcname)
iscallable = callable(f)
self.assertTrue(iscallable)
# def test_torch_script(self):
# for funcname in self.flist:
# f = activations.get_activation(funcname)
# self.assertIsInstance(f, torch.jit.ScriptFunction)
def test_callable_deserialization(self):
def dummy(x, **kwargs):
return x
for f in [dummy, lambda x: x]:
f = activations.get_activation(f)
iscallable = callable(f)
self.assertTrue(iscallable)
self.assertEqual(1, f(1))
def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']:
with self.assertRaises(NameError):
_ = activations.get_activation(funcname)
def test_identity(self):
actual = activations.identity(self.x)
desired = self.x
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_sigmoid_beta1(self):
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
desired = torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_swish_beta1(self):
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
desired = self.x * torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
del self.x
class TestCompetitions(unittest.TestCase):
def setUp(self):
pass
def test_wtac(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.wtac(d, labels)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_unequal_dist(self):
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
labels = torch.tensor([0, 1, 1])
actual = competitions.wtac(d, labels)
desired = torch.tensor([0, 1])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
labels = torch.tensor([[0, 1], [1, 0]])
actual = competitions.wtac(d, labels)
desired = torch.tensor([[0, 1], [1, 0]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
labels = torch.tensor([0, 0, 1, 2])
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min_one_hot(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
labels = torch.tensor([0, 0, 1, 2])
labels = torch.eye(3)[labels]
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min_simple(self):
d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
labels = torch.tensor([0, 1, 2])
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_knnc_k1(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
pass
class TestDistances(unittest.TestCase):
@@ -167,103 +308,12 @@ class TestDistances(unittest.TestCase):
del self.x, self.y
class TestActivations(unittest.TestCase):
def setUp(self):
self.x = torch.randn(1024, 1)
def test_registry(self):
self.assertIsNotNone(activations.ACTIVATIONS)
def test_funcname_deserialization(self):
flist = ['identity', 'sigmoid_beta', 'swish_beta']
for funcname in flist:
f = activations.get_activation(funcname)
iscallable = callable(f)
self.assertTrue(iscallable)
def test_callable_deserialization(self):
def dummy(x, **kwargs):
return x
for f in [dummy, lambda x: x]:
f = activations.get_activation(f)
iscallable = callable(f)
self.assertTrue(iscallable)
self.assertEqual(1, f(1))
def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']:
with self.assertRaises(NameError):
_ = activations.get_activation(funcname)
def test_identity(self):
actual = activations.identity(self.x)
desired = self.x
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_sigmoid_beta1(self):
actual = activations.sigmoid_beta(self.x, beta=1)
desired = torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_swish_beta1(self):
actual = activations.swish_beta(self.x, beta=1)
desired = self.x * torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
del self.x
class TestCompetitions(unittest.TestCase):
def setUp(self):
pass
def test_wtac(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.wtac(d, labels)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
labels = torch.tensor([[0, 1], [1, 0]])
actual = competitions.wtac(d, labels)
desired = torch.tensor([[0, 1], [1, 0]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_knnc_k1(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=1)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
pass
class TestInitializers(unittest.TestCase):
def setUp(self):
self.flist = [
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
'stratified_random'
]
self.x = torch.tensor(
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
dtype=torch.float32)
@@ -274,11 +324,7 @@ class TestInitializers(unittest.TestCase):
self.assertIsNotNone(initializers.INITIALIZERS)
def test_funcname_deserialization(self):
flist = [
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
'stratified_random'
]
for funcname in flist:
for funcname in self.flist:
f = initializers.get_initializer(funcname)
iscallable = callable(f)
self.assertTrue(iscallable)
@@ -336,7 +382,7 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal1(self):
pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
@@ -345,8 +391,9 @@ class TestInitializers(unittest.TestCase):
def test_stratified_random_equal1(self):
pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]])
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
@@ -354,7 +401,7 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal2(self):
pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
[1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
@@ -362,9 +409,20 @@ class TestInitializers(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_random_equal2(self):
pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False)
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
[0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_mean_unequal(self):
pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
[1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
@@ -374,14 +432,86 @@ class TestInitializers(unittest.TestCase):
def test_stratified_random_unequal(self):
pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_random(self.x, self.y, pdist)
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.],
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_mean_unequal_one_hot(self):
pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y]
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
[1., 1., 1.]])
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1,
desired1,
decimal=5)
mismatch = np.testing.assert_array_almost_equal(actual2,
desired2,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_random_unequal_one_hot(self):
pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y]
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1,
desired1,
decimal=5)
mismatch = np.testing.assert_array_almost_equal(actual2,
desired2,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
del self.x, self.y, self.gen
_ = torch.seed()
class TestLosses(unittest.TestCase):
def setUp(self):
pass
def test_glvq_loss_int_labels(self):
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1])
targets = torch.ones(100)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def test_glvq_loss_one_hot_labels(self):
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([[0, 1], [1, 0]])
wl = torch.tensor([1, 0])
targets = torch.stack([wl for _ in range(100)], dim=0)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def test_glvq_loss_one_hot_unequal(self):
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
d = torch.stack(dlist, dim=1)
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
wl = torch.tensor([1, 0])
targets = torch.stack([wl for _ in range(100)], dim=0)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def tearDown(self):
pass

View File

@@ -5,7 +5,7 @@ import unittest
import numpy as np
import torch
from prototorch.modules import prototypes, losses
from prototorch.modules import losses, prototypes
class TestPrototypes(unittest.TestCase):
@@ -16,19 +16,23 @@ class TestPrototypes(unittest.TestCase):
self.y = torch.tensor([0, 0, 1, 1])
self.gen = torch.manual_seed(42)
def test_addprototypes1d_init_without_input_dim(self):
def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(nclasses=1)
_ = prototypes.Prototypes1D(nclasses=2)
def test_addprototypes1d_init_without_nclasses(self):
def test_prototypes1d_init_without_nclasses(self):
with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(input_dim=1)
_ = prototypes.Prototypes1D(input_dim=1)
def test_addprototypes1d_init_without_pdist(self):
p1 = prototypes.AddPrototypes1D(input_dim=6,
nclasses=2,
prototypes_per_class=4,
prototype_initializer='ones')
def test_prototypes1d_init_with_nclasses_1(self):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D(input_dim=6,
nclasses=2,
prototypes_per_class=4,
prototype_initializer='ones')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.ones(8, 6)
@@ -37,11 +41,11 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_addprototypes1d_init_without_data(self):
def test_prototypes1d_init_without_data(self):
pdist = [2, 2]
p1 = prototypes.AddPrototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer='zeros')
p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
@@ -50,23 +54,20 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
# def test_addprototypes1d_init_torch_pdist(self):
# pdist = torch.tensor([2, 2])
# p1 = prototypes.AddPrototypes1D(input_dim=3,
# prototype_distribution=pdist,
# prototype_initializer='zeros')
# protos = p1.prototypes
# actual = protos.detach().numpy()
# desired = torch.zeros(4, 3)
# mismatch = np.testing.assert_array_almost_equal(actual,
# desired,
# decimal=5)
# self.assertIsNone(mismatch)
def test_prototypes1d_proto_init_without_data(self):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(
input_dim=3,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=None)
def test_addprototypes1d_init_with_ppc(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
prototypes_per_class=2,
prototype_initializer='zeros')
def test_prototypes1d_init_torch_pdist(self):
pdist = torch.tensor([2, 2])
p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
@@ -75,10 +76,119 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_addprototypes1d_init_with_pdist(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
prototype_distribution=[6, 9],
prototype_initializer='zeros')
def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D(nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [0.]], [1, 0]])
def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1], [0]], [1, 0]])
def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=None,
one_hot_labels=True)
def test_prototypes1d_init_one_hot_labels_false(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
but the provided `data` has one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [[0, 1], [1, 0]]),
one_hot_labels=False)
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [0, 1]),
one_hot_labels=True)
def test_prototypes1d_init_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` contains 2D targets but
does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [[0], [1]]),
one_hot_labels=True)
def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D(
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1], [0]], [1, 0]],
dtype=torch.int32)
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
prototypes_per_class=1,
data=[[1.], [1]])
def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [0.]], [1, 0]])
def test_prototypes1d_nclasses_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same
as the one computed from the provided `data`.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [2.]], [1, 2]])
def test_prototypes1d_init_with_ppc(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototypes_per_class=2,
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_with_pdist(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototype_distribution=[6, 9],
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(15, 3)
@@ -87,14 +197,14 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_addprototypes1d_func_initializer(self):
def test_prototypes1d_func_initializer(self):
def my_initializer(*args, **kwargs):
return torch.full((2, 99), 99), torch.tensor([0, 1])
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
p1 = prototypes.AddPrototypes1D(input_dim=99,
nclasses=2,
prototypes_per_class=1,
prototype_initializer=my_initializer)
p1 = prototypes.Prototypes1D(input_dim=99,
nclasses=2,
prototypes_per_class=1,
prototype_initializer=my_initializer)
protos = p1.prototypes
actual = protos.detach().numpy()
desired = 99 * torch.ones(2, 99)
@@ -103,8 +213,8 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_addprototypes1d_forward(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
def test_prototypes1d_forward(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
protos, _ = p1()
actual = protos.detach().numpy()
desired = torch.ones(2, 3)
@@ -113,6 +223,16 @@ class TestPrototypes(unittest.TestCase):
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_dist_validate(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
with self.assertWarns(UserWarning):
_ = p1._validate_prototype_distribution()
def test_prototypes1d_validate_extra_repr_not_empty(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
rep = p1.extra_repr()
self.assertNotEqual(rep, '')
def tearDown(self):
del self.x, self.y, self.gen
_ = torch.seed()
@@ -123,7 +243,37 @@ class TestLosses(unittest.TestCase):
pass
def test_glvqloss_init(self):
_ = losses.GLVQLoss()
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
def test_glvqloss_forward_1ppc(self):
criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta',
beta=100)
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def test_glvqloss_forward_2ppc(self):
criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta',
beta=100)
d = torch.stack([
torch.ones(100),
torch.ones(100),
torch.zeros(100),
torch.ones(100)
],
dim=1)
labels = torch.tensor([0, 0, 1, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def tearDown(self):
pass

10
tox.ini
View File

@@ -4,12 +4,12 @@
# and then run "tox" from this directory.
[tox]
envlist = py36
envlist = py36,py37,py38
[testenv]
deps =
numpy
unittest-xml-reporting
pytest
coverage
commands =
python -m xmlrunner -o reports
pip install -e .
coverage run -m pytest