Compare commits
27 Commits
v0.1.1-rc0
...
v0.2.0
Author | SHA1 | Date | |
---|---|---|---|
|
db842b79bb | ||
|
98a8fc52fa | ||
|
6796ec494f | ||
|
429570323e | ||
|
3edb13baf4 | ||
|
42cedbb2b8 | ||
|
2322876eb6 | ||
|
bc7df1059f | ||
|
4c7c9cc34a | ||
|
e39f307194 | ||
|
e2867f696e | ||
|
30dc0ea8b1 | ||
|
895281aabd | ||
|
a55320a65b | ||
|
559f4acc73 | ||
|
9b5bccc39d | ||
|
a8a99f6971 | ||
|
58efa5a4cf | ||
|
9672aab8e2 | ||
|
d5ab9c3771 | ||
|
3e6aa6a20b | ||
|
b138277608 | ||
|
9ccbec52f7 | ||
|
cd652508b9 | ||
|
fa72c7156e | ||
|
6e72b9267a | ||
|
8a4a596035 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.1-rc0
|
||||
current_version = 0.2.0
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
||||
@@ -19,3 +19,4 @@ values =
|
||||
|
||||
[bumpversion:file:./prototorch/__init__.py]
|
||||
|
||||
[bumpversion:file:./docs/source/conf.py]
|
||||
|
15
.codacy.yml
Normal file
15
.codacy.yml
Normal file
@@ -0,0 +1,15 @@
|
||||
# To validate the contents of your configuration file
|
||||
# run the following command in the folder where the configuration file is located:
|
||||
# codacy-analysis-cli validate-configuration --directory `pwd`
|
||||
# To analyse, run:
|
||||
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
||||
---
|
||||
engines:
|
||||
pylintpython3:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
remark-lint:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
exclude_paths:
|
||||
- 'tests/**'
|
27
.readthedocs.yml
Normal file
27
.readthedocs.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
# .readthedocs.yml
|
||||
# Read the Docs configuration file
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/source/conf.py
|
||||
fail_on_warning: true
|
||||
|
||||
# Build documentation with MkDocs
|
||||
# mkdocs:
|
||||
# configuration: mkdocs.yml
|
||||
|
||||
# Optionally build your docs in additional formats such as PDF and ePub
|
||||
formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.8
|
||||
install:
|
||||
- method: pip
|
||||
path: .
|
||||
extra_requirements:
|
||||
- all
|
30
.travis.yml
30
.travis.yml
@@ -4,15 +4,33 @@ language: python
|
||||
python: 3.8
|
||||
cache:
|
||||
directories:
|
||||
- ./tests/artifacts
|
||||
|
||||
- "./tests/artifacts"
|
||||
# - "$HOME/.prototorch/datasets"
|
||||
install:
|
||||
- pip install . --progress-bar off
|
||||
- pip install -r requirements.txt
|
||||
- pip install . --progress-bar off
|
||||
- pip install -r requirements.txt
|
||||
|
||||
# Generate code coverage report
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
- coverage run -m pytest
|
||||
|
||||
# Push the results to codecov
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
password:
|
||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
||||
|
@@ -1,6 +1,8 @@
|
||||
include .bumpversion.cfg
|
||||
include LICENSE
|
||||
include tox.ini
|
||||
include *.md
|
||||
include *.txt
|
||||
include *.yml
|
||||
recursive-include docs *.bat
|
||||
recursive-include docs *.png
|
||||
|
48
README.md
48
README.md
@@ -1,7 +1,6 @@
|
||||
# ProtoTorch
|
||||
# ProtoTorch: Prototype Learning in PyTorch
|
||||
|
||||
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
|
||||
prototype-based machine learning algorithms.
|
||||

|
||||
|
||||
[](https://travis-ci.org/si-cim/prototorch)
|
||||

|
||||
@@ -12,53 +11,48 @@ prototype-based machine learning algorithms.
|
||||

|
||||
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
||||
|
||||
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
||||
|
||||
## Description
|
||||
|
||||
This is a Python toolbox brewed at the Mittweida University of Applied Sciences
|
||||
in Germany for bleeding-edge research in Learning Vector Quantization (LVQ)
|
||||
and potentially other prototype-based methods. Although, there are
|
||||
other (perhaps more extensive) LVQ toolboxes available out there, the focus of
|
||||
ProtoTorch is ease-of-use, extensibility and speed.
|
||||
|
||||
Many popular prototype-based Machine Learning (ML) algorithms like K-Nearest
|
||||
Neighbors (KNN), Generalized Learning Vector Quantization (GLVQ) and Generalized
|
||||
Matrix Learning Vector Quantization (GMLVQ) are implemented using the "nn" API
|
||||
provided by PyTorch.
|
||||
in Germany for bleeding-edge research in Prototype-based Machine Learning
|
||||
methods and other interpretable models. The focus of ProtoTorch is ease-of-use,
|
||||
extensibility and speed.
|
||||
|
||||
## Installation
|
||||
|
||||
ProtoTorch can be installed using `pip`.
|
||||
```bash
|
||||
pip install prototorch
|
||||
pip install -U prototorch
|
||||
```
|
||||
To also install the extras, use
|
||||
```bash
|
||||
pip install -U prototorch[all]
|
||||
```
|
||||
|
||||
*Note: If you're using [ZSH](https://www.zsh.org/), the square brackets `[ ]`
|
||||
have to be escaped like so: `\[\]`, making the install command `pip install -U
|
||||
prototorch\[all\]`.*
|
||||
|
||||
To install the bleeding-edge features and improvements:
|
||||
```bash
|
||||
git clone https://github.com/si-cim/prototorch.git
|
||||
git checkout dev
|
||||
cd prototorch
|
||||
pip install -e .
|
||||
pip install -e .[all]
|
||||
```
|
||||
|
||||
## Usage
|
||||
## Documentation
|
||||
|
||||
ProtoTorch is modular. It is very easy to use the modular pieces provided by
|
||||
ProtoTorch, like the layers, losses, callbacks and metrics to build your own
|
||||
prototype-based(instance-based) models. These pieces blend-in seamlessly with
|
||||
numpy and PyTorch to allow you mix and match the modules from ProtoTorch with
|
||||
other PyTorch modules.
|
||||
|
||||
ProtoTorch comes prepackaged with many popular LVQ algorithms in a convenient
|
||||
API, with more algorithms and techniques coming soon. If you would simply like
|
||||
to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
|
||||
lets you do this without requiring a black-belt in high-performance Tensor
|
||||
computation.
|
||||
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
|
||||
that link not work try <https://prototorch.readthedocs.io/en/latest/>.
|
||||
|
||||
## Bibtex
|
||||
|
||||
If you would like to cite the package, please use this:
|
||||
```bibtex
|
||||
@misc{Ravichandran2020,
|
||||
@misc{Ravichandran2020b,
|
||||
author = {Ravichandran, J},
|
||||
title = {ProtoTorch},
|
||||
year = {2020},
|
||||
|
@@ -1,10 +1,15 @@
|
||||
# ProtoTorch Releases
|
||||
|
||||
## Release 0.2.0
|
||||
|
||||
### Includes
|
||||
- Fixes in example scripts.
|
||||
|
||||
## Release 0.1.1-dev0
|
||||
|
||||
### Includes
|
||||
- Minor bugfixes.
|
||||
- 100% line coverage.
|
||||
- Minor bugfixes.
|
||||
- 100% line coverage.
|
||||
|
||||
## Release 0.1.0-dev0
|
||||
|
||||
|
20
docs/Makefile
Normal file
20
docs/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= python3 -m sphinx
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
35
docs/make.bat
Normal file
35
docs/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.http://sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
4
docs/requirements.txt
Normal file
4
docs/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch==1.6.0
|
||||
matplotlib==3.1.2
|
||||
sphinx_rtd_theme==0.5.0
|
||||
sphinxcontrib-katex==0.6.1
|
BIN
docs/source/_static/img/horizontal-lockup.png
Normal file
BIN
docs/source/_static/img/horizontal-lockup.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 88 KiB |
28
docs/source/api.rst
Normal file
28
docs/source/api.rst
Normal file
@@ -0,0 +1,28 @@
|
||||
.. ProtoFlow API Reference
|
||||
|
||||
ProtoFlow API Reference
|
||||
======================================
|
||||
|
||||
Datasets
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.datasets
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Functions
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.functions
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Modules
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.modules
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
Utilities
|
||||
--------------------------------------
|
||||
.. automodule:: prototorch.utils
|
||||
:members:
|
||||
:undoc-members:
|
180
docs/source/conf.py
Normal file
180
docs/source/conf.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file only contains a selection of the most common options. For a full
|
||||
# list see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath("../../"))
|
||||
|
||||
import sphinx_rtd_theme
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "ProtoTorch"
|
||||
copyright = "2021, Jensun Ravichandran"
|
||||
author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.2.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
#
|
||||
needs_sphinx = "1.6"
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
"recommonmark",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.doctest",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.todo",
|
||||
"sphinx.ext.coverage",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib.katex",
|
||||
]
|
||||
|
||||
# katex_prerender = True
|
||||
katex_prerender = False
|
||||
|
||||
napoleon_use_ivar = True
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
#
|
||||
source_suffix = [".rst", ".md"]
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = "index"
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = []
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use. Choose from:
|
||||
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
|
||||
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
|
||||
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
|
||||
# "lovelace", "algol", "algol_nu", "arduino", "rainbo w_dash", "abap",
|
||||
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
|
||||
# "stata-dark", "inkpot"]
|
||||
pygments_style = "monokai"
|
||||
|
||||
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
||||
todo_include_todos = True
|
||||
|
||||
# Disable docstring inheritance
|
||||
autodoc_inherit_docstrings = False
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
# https://sphinx-themes.org/
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
html_logo = "_static/img/horizontal-lockup.png"
|
||||
|
||||
html_theme_options = {
|
||||
"logo_only": True,
|
||||
"display_version": True,
|
||||
"prev_next_buttons_location": "bottom",
|
||||
"style_external_links": False,
|
||||
"style_nav_header_background": "#ffffff",
|
||||
# Toc options
|
||||
"collapse_navigation": True,
|
||||
"sticky_navigation": True,
|
||||
"navigation_depth": 4,
|
||||
"includehidden": True,
|
||||
"titles_only": False,
|
||||
}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_css_files = [
|
||||
"https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
|
||||
]
|
||||
|
||||
# -- Options for HTMLHelp output ------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = "protoflowdoc"
|
||||
|
||||
# -- Options for LaTeX output ---------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
# The paper size ("letterpaper" or "a4paper").
|
||||
#
|
||||
# "papersize": "letterpaper",
|
||||
|
||||
# The font size ("10pt", "11pt" or "12pt").
|
||||
#
|
||||
# "pointsize": "10pt",
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# "preamble": "",
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# "figure_align": "htbp",
|
||||
}
|
||||
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, "prototorch.tex", "ProtoTorch Documentation",
|
||||
"Jensun Ravichandran", "manual"),
|
||||
]
|
||||
|
||||
# -- Options for manual page output ---------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
|
||||
# Grouping the document tree into Texinfo files. List of tuples
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
|
||||
"Prototype-based machine learning in PyTorch.",
|
||||
"Miscellaneous"),
|
||||
]
|
||||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/", None),
|
||||
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
|
||||
}
|
||||
|
||||
# -- Options for Epub output ----------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
|
||||
|
||||
epub_cover = ()
|
||||
version = release
|
22
docs/source/index.rst
Normal file
22
docs/source/index.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. ProtoTorch documentation master file
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
About ProtoTorch
|
||||
================
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 3
|
||||
:caption: Contents:
|
||||
|
||||
self
|
||||
api
|
||||
|
||||
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
|
||||
research in prototype-based machine learning algorithms.
|
||||
|
||||
Indices
|
||||
=======
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
@@ -3,16 +3,17 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
# Prepare and preprocess the data
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(True)
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
@@ -20,17 +21,19 @@ x_train = scaler.transform(x_train)
|
||||
|
||||
# Define the GLVQ model
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
"""GLVQ model."""
|
||||
def __init__(self):
|
||||
"""GLVQ model for training on 2D Iris data."""
|
||||
super().__init__()
|
||||
self.p1 = Prototypes1D(input_dim=2,
|
||||
prototypes_per_class=1,
|
||||
nclasses=3,
|
||||
prototype_initializer='zeros')
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train])
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
protos = self.proto_layer.prototypes
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
dis = euclidean_distance(x, protos)
|
||||
return dis, plabels
|
||||
|
||||
@@ -38,21 +41,28 @@ class Model(torch.nn.Module):
|
||||
# Build the GLVQ model
|
||||
model = Model()
|
||||
|
||||
# Print summary using torchinfo (might be buggy/incorrect)
|
||||
print(summary(model))
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||
|
||||
x_in = torch.Tensor(x_train)
|
||||
y_in = torch.Tensor(y_train)
|
||||
|
||||
# Training loop
|
||||
title = 'Prototype Visualization'
|
||||
title = "Prototype Visualization"
|
||||
fig = plt.figure(title)
|
||||
for epoch in range(70):
|
||||
# Compute loss
|
||||
dis, plabels = model(x_in)
|
||||
loss = criterion([dis, plabels], y_in)
|
||||
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
|
||||
with torch.no_grad():
|
||||
pred = wtac(dis, plabels)
|
||||
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
||||
acc = 100. * correct / len(x_train)
|
||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
|
||||
|
||||
# Take a gradient descent step
|
||||
optimizer.zero_grad()
|
||||
@@ -60,22 +70,25 @@ for epoch in range(70):
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
protos = model.proto_layer.prototypes.data.numpy()
|
||||
if np.isnan(np.sum(protos)):
|
||||
print("Stopping training because of `nan` in prototypes.")
|
||||
break
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel('Data dimension 1')
|
||||
ax.set_ylabel('Data dimension 2')
|
||||
cmap = 'viridis'
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
cmap = "viridis"
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||
ax.scatter(protos[:, 0],
|
||||
protos[:, 1],
|
||||
c=plabels,
|
||||
cmap=cmap,
|
||||
edgecolor='k',
|
||||
marker='D',
|
||||
edgecolor="k",
|
||||
marker="D",
|
||||
s=50)
|
||||
|
||||
# Paint decision regions
|
||||
@@ -88,8 +101,8 @@ for epoch in range(70):
|
||||
|
||||
torch_input = torch.Tensor(mesh_input)
|
||||
d = model(torch_input)[0]
|
||||
y_pred = np.argmin(d.detach().numpy(),
|
||||
axis=1) # assume one prototype per class
|
||||
w_indices = torch.argmin(d, dim=1)
|
||||
y_pred = torch.index_select(plabels, 0, w_indices)
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
|
||||
# Plot voronoi regions
|
||||
|
102
examples/gmlvq_tecator.py
Normal file
102
examples/gmlvq_tecator.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""ProtoTorch "siamese" GMLVQ example using Tecator."""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from prototorch.datasets.tecator import Tecator
|
||||
from prototorch.functions.distances import sed
|
||||
from prototorch.modules import Prototypes1D
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.utils.colors import get_legend_handles
|
||||
|
||||
# Prepare the dataset and dataloader
|
||||
train_data = Tecator(root="./artifacts", train=True)
|
||||
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
"""GMLVQ model as a siamese network."""
|
||||
super().__init__()
|
||||
x, y = train_data.data, train_data.targets
|
||||
self.p1 = Prototypes1D(input_dim=100,
|
||||
prototypes_per_class=2,
|
||||
nclasses=2,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x, y])
|
||||
self.omega = torch.nn.Linear(in_features=100,
|
||||
out_features=100,
|
||||
bias=False)
|
||||
torch.nn.init.eye_(self.omega.weight)
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
|
||||
# Process `x` and `protos` through `omega`
|
||||
x_map = self.omega(x)
|
||||
protos_map = self.omega(protos)
|
||||
|
||||
# Compute distances and output
|
||||
dis = sed(x_map, protos_map)
|
||||
return dis, plabels
|
||||
|
||||
|
||||
# Build the GLVQ model
|
||||
model = Model()
|
||||
|
||||
# Print a summary of the model
|
||||
print(model)
|
||||
|
||||
# Optimize using Adam optimizer from `torch.optim`
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
|
||||
criterion = GLVQLoss(squashing="identity", beta=10)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(150):
|
||||
epoch_loss = 0.0 # zero-out epoch loss
|
||||
optimizer.zero_grad() # zero-out gradients
|
||||
for xb, yb in train_loader:
|
||||
# Compute loss
|
||||
distances, plabels = model(xb)
|
||||
loss = criterion([distances, plabels], yb)
|
||||
epoch_loss += loss.item()
|
||||
# Backprop
|
||||
loss.backward()
|
||||
# Take a gradient descent step
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")
|
||||
|
||||
# Get the omega matrix form the model
|
||||
omega = model.omega.weight.data.numpy().T
|
||||
|
||||
# Visualize the lambda matrix
|
||||
title = "Lambda Matrix Visualization"
|
||||
fig = plt.figure(title)
|
||||
ax = fig.gca()
|
||||
ax.set_title(title)
|
||||
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
|
||||
plt.show()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
plabels = model.p1.prototype_labels
|
||||
|
||||
# Visualize the prototypes
|
||||
title = "Tecator Prototypes"
|
||||
fig = plt.figure(title)
|
||||
ax = fig.gca()
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("Spectral frequencies")
|
||||
ax.set_ylabel("Absorption")
|
||||
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
|
||||
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
|
||||
for x, y in zip(protos, plabels):
|
||||
ax.plot(x, c=colors[int(y)])
|
||||
ax.legend(handles, clabels)
|
||||
plt.show()
|
162
examples/gtlvq_mnist.py
Normal file
162
examples/gtlvq_mnist.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
ProtoTorch GTLVQ example using MNIST data.
|
||||
The GTLVQ is placed as an classification model on
|
||||
top of a CNN, considered as featurer extractor.
|
||||
Initialization of subpsace and prototypes in
|
||||
Siamnese fashion
|
||||
For more info about GTLVQ see:
|
||||
DOI:10.1109/IJCNN.2016.7727534
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||
from prototorch.modules.models import GTLVQ
|
||||
|
||||
# Parameters and options
|
||||
n_epochs = 50
|
||||
batch_size_train = 64
|
||||
batch_size_test = 1000
|
||||
learning_rate = 0.1
|
||||
momentum = 0.5
|
||||
log_interval = 10
|
||||
cuda = "cuda:1"
|
||||
random_seed = 1
|
||||
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Configures reproducability
|
||||
torch.manual_seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
# Prepare and preprocess the data
|
||||
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_test,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
# Define the GLVQ model plus appropriate feature extractor
|
||||
class CNNGTLVQ(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
prototypes_per_class=2,
|
||||
bottleneck_dim=128,
|
||||
):
|
||||
super(CNNGTLVQ, self).__init__()
|
||||
|
||||
#Feature Extractor - Simple CNN
|
||||
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
|
||||
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
|
||||
nn.MaxPool2d(2), nn.Dropout(0.25),
|
||||
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
|
||||
nn.Dropout(0.5), nn.LeakyReLU(),
|
||||
nn.LayerNorm(bottleneck_dim))
|
||||
|
||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
||||
subspace_data = self.fe(subspace_data)
|
||||
prototype_data[0] = self.fe(prototype_data[0])
|
||||
|
||||
# Initialization of GTLVQ
|
||||
self.gtlvq = GTLVQ(num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type=tangent_projection_type,
|
||||
feature_dim=bottleneck_dim,
|
||||
prototypes_per_class=prototypes_per_class)
|
||||
|
||||
def forward(self, x):
|
||||
# Feature Extraction
|
||||
x = self.fe(x)
|
||||
|
||||
# GTLVQ Forward pass
|
||||
dis = self.gtlvq(x)
|
||||
return dis
|
||||
|
||||
|
||||
# Get init data
|
||||
subspace_data = torch.cat(
|
||||
[next(iter(train_loader))[0],
|
||||
next(iter(test_loader))[0]])
|
||||
prototype_data = next(iter(train_loader))
|
||||
|
||||
# Build the CNN GTLVQ model
|
||||
model = CNNGTLVQ(10,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
bottleneck_dim=128).to(device)
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.Adam([{
|
||||
'params': model.fe.parameters()
|
||||
}, {
|
||||
'params': model.gtlvq.parameters()
|
||||
}],
|
||||
lr=learning_rate)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(n_epochs):
|
||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
||||
model.train()
|
||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
distances = model(x_train)
|
||||
plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
|
||||
# Compute loss.
|
||||
loss = criterion([distances, plabels], y_train)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
|
||||
model.gtlvq.orthogonalize_subspace()
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||
print(
|
||||
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||
Train Acc: {acc.item():02.02f}')
|
||||
|
||||
# Test
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for x_test, y_test in test_loader:
|
||||
x_test, y_test = x_test.to(device), y_test.to(device)
|
||||
test_distances = model(torch.tensor(x_test))
|
||||
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
i = torch.argmin(test_distances, 1)
|
||||
correct += torch.sum(y_test == test_plabels[i])
|
||||
total += y_test.size(0)
|
||||
print('Accuracy of the network on the test images: %d %%' %
|
||||
(torch.true_divide(correct, total) * 100))
|
||||
|
||||
# Save the model
|
||||
PATH = './glvq_mnist_model.pth'
|
||||
torch.save(model.state_dict(), PATH)
|
106
examples/lgmlvq_iris.py
Normal file
106
examples/lgmlvq_iris.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""ProtoTorch LGMLVQ example using 2D Iris data."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
from prototorch.functions.competitions import stratified_min
|
||||
from prototorch.functions.distances import lomega_distance
|
||||
from prototorch.functions.init import eye_
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
# Prepare training data
|
||||
x_train, y_train = load_iris(True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
|
||||
|
||||
# Define the model
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
"""Local-GMLVQ model."""
|
||||
super().__init__()
|
||||
self.p1 = Prototypes1D(input_dim=2,
|
||||
prototype_distribution=[1, 2, 2],
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train])
|
||||
omegas = torch.zeros(5, 2, 2)
|
||||
self.omegas = torch.nn.Parameter(omegas)
|
||||
eye_(self.omegas)
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
omegas = self.omegas
|
||||
dis = lomega_distance(x, protos, omegas)
|
||||
return dis, plabels
|
||||
|
||||
|
||||
# Build the model
|
||||
model = Model()
|
||||
|
||||
# Optimize using Adam optimizer from `torch.optim`
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||
|
||||
x_in = torch.Tensor(x_train)
|
||||
y_in = torch.Tensor(y_train)
|
||||
|
||||
# Training loop
|
||||
title = "Prototype Visualization"
|
||||
fig = plt.figure(title)
|
||||
for epoch in range(100):
|
||||
# Compute loss
|
||||
dis, plabels = model(x_in)
|
||||
loss = criterion([dis, plabels], y_in)
|
||||
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
|
||||
acc = accuracy_score(y_train, y_pred)
|
||||
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
|
||||
log_string += f"Acc: {acc * 100:05.02f}%"
|
||||
print(log_string)
|
||||
|
||||
# Take a gradient descent step
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
cmap = "viridis"
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
||||
ax.scatter(protos[:, 0],
|
||||
protos[:, 1],
|
||||
c=plabels,
|
||||
cmap=cmap,
|
||||
edgecolor='k',
|
||||
marker='D',
|
||||
s=50)
|
||||
|
||||
# Paint decision regions
|
||||
x = np.vstack((x_train, protos))
|
||||
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
||||
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
||||
np.arange(y_min, y_max, 1 / 50))
|
||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||
|
||||
d, plabels = model(torch.Tensor(mesh_input))
|
||||
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
|
||||
# Plot voronoi regions
|
||||
ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)
|
||||
|
||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||
|
||||
plt.pause(0.1)
|
@@ -1,6 +1,6 @@
|
||||
"""ProtoTorch package."""
|
||||
|
||||
__version__ = '0.1.1-rc0'
|
||||
__version__ = '0.2.0'
|
||||
|
||||
from prototorch import datasets, functions, modules
|
||||
|
||||
|
@@ -1,6 +1,8 @@
|
||||
"""ProtoTorch distance functions."""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
|
||||
import numpy as np
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
@@ -71,5 +73,155 @@ def lomega_distance(x, y, omegas):
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
||||
r""" Computes an euclidean distanes matrix given two distinct vectors.
|
||||
last dimension must be the vector dimension!
|
||||
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
|
||||
|
||||
x.shape = (number_of_x_vectors, vector_dim)
|
||||
y.shape = (number_of_y_vectors, vector_dim)
|
||||
|
||||
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
|
||||
"""
|
||||
for tensor in [x, y]:
|
||||
if tensor.ndim != 2:
|
||||
raise ValueError(
|
||||
'The tensor dimension must be two. You provide: tensor.ndim=' +
|
||||
str(tensor.ndim) + '.')
|
||||
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
||||
raise ValueError(
|
||||
'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
|
||||
+ str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
|
||||
str(tuple(y.shape)[1]) + '.')
|
||||
|
||||
y = torch.transpose(y)
|
||||
|
||||
diss = torch.sum(x**2, axis=1,
|
||||
keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
|
||||
y**2, axis=0, keepdims=True)
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
return diss
|
||||
|
||||
|
||||
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
||||
r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
|
||||
For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
|
||||
The subspaces is always assumed as transposed and must be orthogonal!
|
||||
For local non sparse signals subspaces must be provided!
|
||||
shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
subspace should be orthogonalized
|
||||
Pytorch implementation of Sascha Saralajew's tensorflow code.
|
||||
Translation by Christoph Raab
|
||||
"""
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
subspace_int_shape = tuple(subspaces.shape)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
atom_axes = list(range(3, len(signal_int_shape)))
|
||||
# for sparse signals, we use the memory efficient implementation
|
||||
if signal_int_shape[1] == 1:
|
||||
signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
|
||||
|
||||
if len(atom_axes) > 1:
|
||||
protos = torch.reshape(protos, [proto_shape[0], -1])
|
||||
|
||||
if subspaces.ndim == 2:
|
||||
# clean solution without map if the matrix_scope is global
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
||||
subspaces, torch.transpose(subspaces))
|
||||
|
||||
projected_signals = torch.dot(signals, projectors)
|
||||
projected_protos = torch.dot(protos, projectors)
|
||||
|
||||
diss = euclidean_distance_matrix(projected_signals,
|
||||
projected_protos,
|
||||
squared=squared,
|
||||
epsilon=epsilon)
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
|
||||
# no solution without map possible --> memory efficient but slow!
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
||||
subspaces,
|
||||
subspaces) #K.batch_dot(subspaces, subspaces, [2, 2])
|
||||
|
||||
projected_protos = (protos @ subspaces
|
||||
).T #K.batch_dot(projectors, protos, [1, 1]))
|
||||
|
||||
def projected_norm(projector):
|
||||
return torch.sum(torch.dot(signals, projector)**2, axis=1)
|
||||
|
||||
diss = torch.transpose(map(projected_norm, projectors)) \
|
||||
- 2 * torch.dot(signals, projected_protos) \
|
||||
+ torch.sum(projected_protos**2, axis=0, keepdims=True)
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
signals = signals.permute([0, 2, 1] + atom_axes)
|
||||
|
||||
diff = signals - protos
|
||||
|
||||
# global tangent space
|
||||
if subspaces.ndim == 2:
|
||||
#Scope Projectors
|
||||
projectors = subspaces #
|
||||
|
||||
#Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
projected_diff = diff @ projectors
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[0], signal_shape[2], signal_shape[1]) +
|
||||
signal_shape[3:])
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([0, 2, 1])
|
||||
|
||||
# local tangent spaces
|
||||
else:
|
||||
# Scope: Calculate Projectors
|
||||
projectors = subspaces
|
||||
|
||||
# Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
diff = diff.permute([1, 0, 2])
|
||||
projected_diff = torch.bmm(diff, projectors)
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:])
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1)
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
||||
|
89
prototorch/functions/helper.py
Normal file
89
prototorch/functions/helper.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
|
||||
|
||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||
"""Computes the accuracy of a prototype based model.
|
||||
via Winner-Takes-All rule.
|
||||
Requirement:
|
||||
y_pred.shape == y_true.shape
|
||||
unique(y_pred) in plabels
|
||||
"""
|
||||
with torch.no_grad():
|
||||
idx = torch.argmin(y_pred, axis=1)
|
||||
return torch.true_divide(torch.sum(y_true == plabels[idx]),
|
||||
len(y_pred)) * 100
|
||||
|
||||
|
||||
def predict_label(y_pred, plabels):
|
||||
r""" Predicts labels given a prediction of a prototype based model.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return plabels[torch.argmin(y_pred, 1)]
|
||||
|
||||
|
||||
def mixed_shape(inputs):
|
||||
if not torch.is_tensor(inputs):
|
||||
raise ValueError('Input must be a tensor.')
|
||||
else:
|
||||
int_shape = list(inputs.shape)
|
||||
# sometimes int_shape returns mixed integer types
|
||||
int_shape = [int(i) if i is not None else i for i in int_shape]
|
||||
tensor_shape = inputs.shape
|
||||
|
||||
for i, s in enumerate(int_shape):
|
||||
if s is None:
|
||||
int_shape[i] = tensor_shape[i]
|
||||
return tuple(int_shape)
|
||||
|
||||
|
||||
def equal_int_shape(shape_1, shape_2):
|
||||
if not isinstance(shape_1,
|
||||
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
||||
raise ValueError('Input shapes must list or tuple.')
|
||||
for shape in [shape_1, shape_2]:
|
||||
if not all([isinstance(x, int) or x is None for x in shape]):
|
||||
raise ValueError(
|
||||
'Input shapes must be list or tuple of int and None values.')
|
||||
|
||||
if len(shape_1) != len(shape_2):
|
||||
return False
|
||||
else:
|
||||
for axis, value in enumerate(shape_1):
|
||||
if value is not None and shape_2[axis] not in {value, None}:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_shapes(signal_int_shape, proto_int_shape):
|
||||
if len(signal_int_shape) < 4:
|
||||
raise ValueError(
|
||||
"The number of signal dimensions must be >=4. You provide: " +
|
||||
str(len(signal_int_shape)))
|
||||
|
||||
if len(proto_int_shape) < 2:
|
||||
raise ValueError(
|
||||
"The number of proto dimensions must be >=2. You provide: " +
|
||||
str(len(proto_int_shape)))
|
||||
|
||||
if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
|
||||
raise ValueError(
|
||||
"The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
|
||||
+ str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
|
||||
str(proto_int_shape[1:]))
|
||||
|
||||
# not a sparse signal
|
||||
if signal_int_shape[1] != 1:
|
||||
if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
|
||||
raise ValueError(
|
||||
"If the signal is not sparse, the number of prototypes must be equal in signals and "
|
||||
"protos. You provide: " + str(signal_int_shape[1]) + " != " +
|
||||
str(proto_int_shape[0]))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _int_and_mixed_shape(tensor):
|
||||
shape = mixed_shape(tensor)
|
||||
int_shape = tuple([i if isinstance(i, int) else None for i in shape])
|
||||
|
||||
return shape, int_shape
|
@@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
def stratified_random(x_train,
|
||||
y_train,
|
||||
prototype_distribution,
|
||||
one_hot=True,
|
||||
epsilon=1e-7):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(nprotos, pdim)
|
||||
@@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
xl = x_train[matcher]
|
||||
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||
random_xl = xl[rand_index]
|
||||
protos[i] = random_xl
|
||||
protos[i] = random_xl + epsilon
|
||||
plabels = labels_from(prototype_distribution, one_hot=one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
37
prototorch/functions/normalization.py
Normal file
37
prototorch/functions/normalization.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def orthogonalization(tensors):
|
||||
r""" Orthogonalization of a given tensor via polar decomposition.
|
||||
"""
|
||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||
u_shape = tuple(list(u.shape))
|
||||
v_shape = tuple(list(v.shape))
|
||||
|
||||
# reshape to (num x N x M)
|
||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||
|
||||
out = u @ v.permute([0, 2, 1])
|
||||
|
||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def trace_normalization(tensors):
|
||||
r""" Trace normalization
|
||||
"""
|
||||
epsilon = torch.tensor([1e-10], dtype=torch.float64)
|
||||
# Scope trace_normalization
|
||||
constant = torch.trace(tensors)
|
||||
|
||||
if epsilon != 0:
|
||||
constant = torch.max(constant, epsilon)
|
||||
|
||||
return tensors / constant
|
@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, plabels)
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
190
prototorch/modules/models.py
Normal file
190
prototorch/modules/models.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes: int
|
||||
Number of classes of the given classification problem.
|
||||
|
||||
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
||||
Subspace data for the point approximation, required
|
||||
|
||||
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
||||
prototype data for initalization of the prototypes used in GTLVQ.
|
||||
|
||||
subspace_size: int (default=256,optional)
|
||||
Subspace dimension of the Projectors. Currently only supported
|
||||
with tagnent_projection_type=global.
|
||||
|
||||
tangent_projection_type: string
|
||||
Specifies the tangent projection type
|
||||
options: local
|
||||
local_proj
|
||||
global
|
||||
local: computes the tangent distances without emphasizing projected
|
||||
data. Only distances are available
|
||||
local_proj: computs tangent distances and returns the projected data
|
||||
for further use. Be careful: data is repeated by number of prototypes
|
||||
global: Number of subspaces is set to one and every prototypes
|
||||
uses the same.
|
||||
|
||||
prototypes_per_class: int (default=2,optional)
|
||||
Number of prototypes per class
|
||||
|
||||
feature_dim: int (default=256)
|
||||
Dimensionality of the feature space specified as integer.
|
||||
Prototype dimension.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The GTLVQ [1] is a prototype-based classification learning model. The
|
||||
GTLVQ uses the Tangent-Distances for a local point approximation
|
||||
of an assumed data manifold via prototypial representations.
|
||||
|
||||
The GTLVQ requires subspace projectors for transforming the data
|
||||
and prototypes into the affine subspace. Every prototype is
|
||||
equipped with a specific subpspace and represents a point
|
||||
approximation of the assumed manifold.
|
||||
|
||||
In practice prototypes and data are projected on this manifold
|
||||
and pairwise euclidean distance computes.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
||||
in classification based on manifolc. models and its relation
|
||||
to tangent metric learning. In: 2017 International Joint
|
||||
Conference on Neural Networks (IJCNN).
|
||||
Bd. 2017-May : IEEE, 2017, S. 1756–1765
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data=None,
|
||||
prototype_data=None,
|
||||
subspace_size=256,
|
||||
tangent_projection_type='local',
|
||||
prototypes_per_class=2,
|
||||
feature_dim=256,
|
||||
):
|
||||
super(GTLVQ, self).__init__()
|
||||
|
||||
self.num_protos = num_classes * prototypes_per_class
|
||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||
self.feature_dim = feature_dim
|
||||
|
||||
if subspace_data is None:
|
||||
raise ValueError('Init Data must be specified!')
|
||||
|
||||
self.tpt = tangent_projection_type
|
||||
with torch.no_grad():
|
||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
||||
self.init_local_subspace(subspace_data)
|
||||
elif self.tpt == 'global':
|
||||
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||
else:
|
||||
self.subspaces = None
|
||||
|
||||
# Hypothesis-Margin-Classifier
|
||||
self.cls = Prototypes1D(input_dim=feature_dim,
|
||||
prototypes_per_class=prototypes_per_class,
|
||||
nclasses=num_classes,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=prototype_data)
|
||||
|
||||
def forward(self, x):
|
||||
# Tangent Projection
|
||||
if self.tpt == 'local_proj':
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis, proj_x = self.local_tangent_projection(x_conform)
|
||||
|
||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
||||
self.feature_dim)
|
||||
return proj_x, dis
|
||||
elif self.tpt == "local":
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
||||
self.subspaces)
|
||||
elif self.tpt == "gloabl":
|
||||
dis = self.global_tangent_distances(x)
|
||||
else:
|
||||
dis = (x @ self.cls.prototypes.T) / (
|
||||
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
|
||||
self.cls.prototypes, dim=1, keepdim=True).T)
|
||||
return dis
|
||||
|
||||
def init_gobal_subspace(self, data, num_subspaces):
|
||||
_, _, v = torch.svd(data)
|
||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
subspaces = subspace[:, :num_subspaces]
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
subspaces).clone().detach().requires_grad_(True)
|
||||
|
||||
def init_local_subspace(self, data):
|
||||
_, _, v = torch.svd(data)
|
||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
||||
self.num_protos, 0)
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
subspaces).clone().detach().requires_grad_(True)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
|
||||
# Euclidean Distance
|
||||
return euclidean_distance_matrix(x, projected_prototypes)
|
||||
|
||||
def local_tangent_projection(self,
|
||||
signals):
|
||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
# subspace should be orthogonalized
|
||||
# Origin Source Code
|
||||
# Origin Author:
|
||||
protos = self.cls.prototypes
|
||||
subspaces = self.subspaces
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
_, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
# Tangent Data Projections
|
||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
||||
data = signals.squeeze(2).permute([1, 0, 2])
|
||||
projected_data = torch.bmm(data, subspaces)
|
||||
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
|
||||
diff = projected_data - projected_protos
|
||||
projected_diff = torch.reshape(
|
||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:])
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
||||
|
||||
def get_parameters(self):
|
||||
return {
|
||||
"params": self.cls.prototypes,
|
||||
}, {
|
||||
"params": self.subspaces
|
||||
}
|
||||
|
||||
def orthogonalize_subspace(self):
|
||||
if self.subspaces is not None:
|
||||
with torch.no_grad():
|
||||
ortho_subpsaces = orthogonalization(
|
||||
self.subspaces
|
||||
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
|
||||
self.subspaces)
|
||||
self.subspaces.copy_(ortho_subpsaces)
|
@@ -2,11 +2,8 @@
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import sed
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
@@ -28,40 +25,9 @@ class _Prototypes(torch.nn.Module):
|
||||
|
||||
|
||||
class Prototypes1D(_Prototypes):
|
||||
r"""Create a learnable set of one-dimensional prototypes.
|
||||
"""Create a learnable set of one-dimensional prototypes.
|
||||
|
||||
TODO Complete this doc-string
|
||||
|
||||
Kwargs:
|
||||
prototypes_per_class: number of prototypes to use per class.
|
||||
Default: ``1``
|
||||
prototype_initializer: prototype initializer.
|
||||
Default: ``'ones'``
|
||||
prototype_distribution: prototype distribution vector.
|
||||
Default: ``None``
|
||||
input_dim: dimension of the incoming data.
|
||||
nclasses: number of classes.
|
||||
data: If set to ``None``, data-dependent initializers will be ignored.
|
||||
Default: ``None``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, H_{in})`
|
||||
where :math:`H_{in} = \text{input_dim}`.
|
||||
- Output: :math:`(N, H_{out})`
|
||||
where :math:`H_{out} = \text{total_prototypes}`.
|
||||
|
||||
Attributes:
|
||||
prototypes: the learnable weights of the module of shape
|
||||
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
|
||||
prototype_labels: the non-learnable labels of the prototypes.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> p = Prototypes1D(input_dim=20, nclasses=10)
|
||||
>>> input = torch.randn(128, 20)
|
||||
>>> output = m(input)
|
||||
>>> print(output.size())
|
||||
torch.Size([20, 10])
|
||||
TODO Complete this doc-string.
|
||||
"""
|
||||
def __init__(self,
|
||||
prototypes_per_class=1,
|
||||
@@ -162,4 +128,5 @@ class Prototypes1D(_Prototypes):
|
||||
|
||||
# Register module parameters
|
||||
self.prototypes = torch.nn.Parameter(prototypes)
|
||||
self.prototype_labels = prototype_labels
|
||||
self.prototype_labels = torch.nn.Parameter(
|
||||
prototype_labels.type(dtype)).requires_grad_(False)
|
||||
|
1
prototorch/utils/__init__.py
Normal file
1
prototorch/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .colors import color_scheme, get_legend_handles
|
74
prototorch/utils/colors.py
Normal file
74
prototorch/utils/colors.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""ProtoFlow color utilities."""
|
||||
|
||||
from matplotlib import cm
|
||||
from matplotlib.colors import Normalize
|
||||
from matplotlib.colors import to_hex
|
||||
from matplotlib.colors import to_rgb
|
||||
import matplotlib.lines as mlines
|
||||
|
||||
|
||||
def color_scheme(n, cmap="viridis", form="hex", tikz=False,
|
||||
zero_indexed=False):
|
||||
"""Return *n* colors from the color scheme.
|
||||
|
||||
Arguments:
|
||||
n (int): number of colors to return
|
||||
|
||||
Keyword Arguments:
|
||||
cmap (str): Name of a matplotlib `colormap\
|
||||
<https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html>`_.
|
||||
form (str): Colorformat (supports "hex" and "rgb").
|
||||
tikz (bool): Output as `TikZ <https://github.com/pgf-tikz/pgf>`_
|
||||
command.
|
||||
zero_indexed (bool): Use zero indexing for output array.
|
||||
|
||||
Returns:
|
||||
(list): List of colors
|
||||
"""
|
||||
cmap = cm.get_cmap(cmap)
|
||||
colornorm = Normalize(vmin=1, vmax=n)
|
||||
hex_map = dict()
|
||||
rgb_map = dict()
|
||||
for cl in range(1, n + 1):
|
||||
if zero_indexed:
|
||||
hex_map[cl - 1] = to_hex(cmap(colornorm(cl)))
|
||||
rgb_map[cl - 1] = to_rgb(cmap(colornorm(cl)))
|
||||
else:
|
||||
hex_map[cl] = to_hex(cmap(colornorm(cl)))
|
||||
rgb_map[cl] = to_rgb(cmap(colornorm(cl)))
|
||||
if tikz:
|
||||
for k, v in rgb_map.items():
|
||||
print(f"\\definecolor{{color-{k}}}{{rgb}}{{{v[0]},{v[1]},{v[2]}}}")
|
||||
if form == "hex":
|
||||
return hex_map
|
||||
elif form == "rgb":
|
||||
return rgb_map
|
||||
else:
|
||||
return hex_map
|
||||
|
||||
|
||||
def get_legend_handles(labels, marker="dots", zero_indexed=False):
|
||||
"""Return matplotlib legend handles and colors."""
|
||||
handles = list()
|
||||
n = len(labels)
|
||||
colors = color_scheme(n,
|
||||
cmap="viridis",
|
||||
form="hex",
|
||||
zero_indexed=zero_indexed)
|
||||
for label, color in zip(labels, colors.values()):
|
||||
if marker == "dots":
|
||||
handle = mlines.Line2D([], [],
|
||||
color="white",
|
||||
markerfacecolor=color,
|
||||
marker="o",
|
||||
markersize=10,
|
||||
markeredgecolor="k",
|
||||
label=label)
|
||||
else:
|
||||
handle = mlines.Line2D([], [],
|
||||
color=color,
|
||||
marker="",
|
||||
markersize=15,
|
||||
label=label)
|
||||
handles.append(handle)
|
||||
return handles, colors
|
89
setup.py
89
setup.py
@@ -3,48 +3,69 @@
|
||||
from setuptools import setup
|
||||
from setuptools import find_packages
|
||||
|
||||
PROJECT_URL = 'https://github.com/si-cim/prototorch'
|
||||
DOWNLOAD_URL = 'https://github.com/si-cim/prototorch.git'
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||
|
||||
with open('README.md', 'r') as fh:
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setup(name='prototorch',
|
||||
version='0.1.1-rc0',
|
||||
description='Highly extensible, GPU-supported '
|
||||
'Learning Vector Quantization (LVQ) toolbox '
|
||||
'built using PyTorch and its nn API.',
|
||||
INSTALL_REQUIRES = [
|
||||
"torch>=1.3.1",
|
||||
"torchvision>=0.5.0",
|
||||
"numpy>=1.9.1",
|
||||
]
|
||||
DOCS = [
|
||||
"recommonmark",
|
||||
"sphinx",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib-katex",
|
||||
]
|
||||
DATASETS = [
|
||||
"requests",
|
||||
"tqdm",
|
||||
]
|
||||
EXAMPLES = [
|
||||
"sklearn",
|
||||
"matplotlib",
|
||||
"torchinfo",
|
||||
]
|
||||
TESTS = ["pytest"]
|
||||
ALL = DOCS + DATASETS + EXAMPLES + TESTS
|
||||
|
||||
setup(name="prototorch",
|
||||
version="0.2.0",
|
||||
description="Highly extensible, GPU-supported "
|
||||
"Learning Vector Quantization (LVQ) toolbox "
|
||||
"built using PyTorch and its nn API.",
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
author='Jensun Ravichandran',
|
||||
author_email='jjensun@gmail.com',
|
||||
long_description_content_type="text/markdown",
|
||||
author="Jensun Ravichandran",
|
||||
author_email="jjensun@gmail.com",
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license='MIT',
|
||||
install_requires=[
|
||||
'torch>=1.3.1',
|
||||
'torchvision>=0.5.0',
|
||||
'numpy>=1.9.1',
|
||||
],
|
||||
license="MIT",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
'datasets': ['requests'],
|
||||
'examples': [
|
||||
'sklearn',
|
||||
'matplotlib',
|
||||
],
|
||||
'tests': ['pytest'],
|
||||
"docs": DOCS,
|
||||
"datasets": DATASETS,
|
||||
"examples": EXAMPLES,
|
||||
"tests": TESTS,
|
||||
"all": ALL,
|
||||
},
|
||||
classifiers=[
|
||||
'Development Status :: 2 - Pre-Alpha', 'Environment :: Console',
|
||||
'Intended Audience :: Developers', 'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Operating System :: OS Independent',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules'
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
],
|
||||
packages=find_packages())
|
||||
|
@@ -199,7 +199,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
|
||||
def test_prototypes1d_func_initializer(self):
|
||||
def my_initializer(*args, **kwargs):
|
||||
return torch.full((2, 99), 99), torch.tensor([0, 1])
|
||||
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
||||
|
||||
p1 = prototypes.Prototypes1D(input_dim=99,
|
||||
nclasses=2,
|
||||
@@ -245,7 +245,7 @@ class TestLosses(unittest.TestCase):
|
||||
def test_glvqloss_init(self):
|
||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
||||
|
||||
def test_glvqloss_forward(self):
|
||||
def test_glvqloss_forward_1ppc(self):
|
||||
criterion = losses.GLVQLoss(margin=0,
|
||||
squashing='sigmoid_beta',
|
||||
beta=100)
|
||||
@@ -257,5 +257,23 @@ class TestLosses(unittest.TestCase):
|
||||
loss_value = loss.item()
|
||||
self.assertAlmostEqual(loss_value, 0.0)
|
||||
|
||||
def test_glvqloss_forward_2ppc(self):
|
||||
criterion = losses.GLVQLoss(margin=0,
|
||||
squashing='sigmoid_beta',
|
||||
beta=100)
|
||||
d = torch.stack([
|
||||
torch.ones(100),
|
||||
torch.ones(100),
|
||||
torch.zeros(100),
|
||||
torch.ones(100)
|
||||
],
|
||||
dim=1)
|
||||
labels = torch.tensor([0, 0, 1, 1])
|
||||
targets = torch.ones(100)
|
||||
outputs = [d, labels]
|
||||
loss = criterion(outputs, targets)
|
||||
loss_value = loss.item()
|
||||
self.assertAlmostEqual(loss_value, 0.0)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user