Merge version 0.2.0 into feature/plugin-architecture.
This commit is contained in:
		@@ -1,5 +1,5 @@
 | 
				
			|||||||
[bumpversion]
 | 
					[bumpversion]
 | 
				
			||||||
current_version = 0.1.1-rc0
 | 
					current_version = 0.2.0
 | 
				
			||||||
commit = True
 | 
					commit = True
 | 
				
			||||||
tag = True
 | 
					tag = True
 | 
				
			||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
 | 
					parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
 | 
				
			||||||
@@ -19,3 +19,4 @@ values =
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:./prototorch/__init__.py]
 | 
					[bumpversion:file:./prototorch/__init__.py]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[bumpversion:file:./docs/source/conf.py]
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										27
									
								
								.readthedocs.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								.readthedocs.yml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					# .readthedocs.yml
 | 
				
			||||||
 | 
					# Read the Docs configuration file
 | 
				
			||||||
 | 
					# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Required
 | 
				
			||||||
 | 
					version: 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Build documentation in the docs/ directory with Sphinx
 | 
				
			||||||
 | 
					sphinx:
 | 
				
			||||||
 | 
					  configuration: docs/source/conf.py
 | 
				
			||||||
 | 
					  fail_on_warning: true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Build documentation with MkDocs
 | 
				
			||||||
 | 
					# mkdocs:
 | 
				
			||||||
 | 
					#   configuration: mkdocs.yml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Optionally build your docs in additional formats such as PDF and ePub
 | 
				
			||||||
 | 
					formats: all
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Optionally set the version of Python and requirements required to build your docs
 | 
				
			||||||
 | 
					python:
 | 
				
			||||||
 | 
					  version: 3.8
 | 
				
			||||||
 | 
					  install:
 | 
				
			||||||
 | 
					    - method: pip
 | 
				
			||||||
 | 
					      path: .
 | 
				
			||||||
 | 
					      extra_requirements:
 | 
				
			||||||
 | 
					        - all
 | 
				
			||||||
							
								
								
									
										18
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								README.md
									
									
									
									
									
								
							@@ -45,22 +45,8 @@ pip install -e .[all]
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
## Documentation
 | 
					## Documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The documentation is available at <https://prototorch.readthedocs.io/en/latest/>
 | 
					The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
 | 
				
			||||||
 | 
					that link not work try <https://prototorch.readthedocs.io/en/latest/>.
 | 
				
			||||||
## Usage
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### For researchers
 | 
					 | 
				
			||||||
ProtoTorch is modular. It is very easy to use the modular pieces provided by
 | 
					 | 
				
			||||||
ProtoTorch, like the layers, losses, callbacks and metrics to build your own
 | 
					 | 
				
			||||||
prototype-based(instance-based) models. These pieces blend-in seamlessly with
 | 
					 | 
				
			||||||
Keras allowing you to mix and match the modules from ProtoFlow with other
 | 
					 | 
				
			||||||
modules in `torch.nn`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### For engineers
 | 
					 | 
				
			||||||
ProtoTorch comes prepackaged with many popular Learning Vector Quantization
 | 
					 | 
				
			||||||
(LVQ)-like algorithms in a convenient API. If you would simply like to be able
 | 
					 | 
				
			||||||
to use those algorithms to train large ML models on a GPU, ProtoTorch lets you
 | 
					 | 
				
			||||||
do this without requiring a black-belt in high-performance Tensor computing.
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Bibtex
 | 
					## Bibtex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,10 @@
 | 
				
			|||||||
# ProtoTorch Releases
 | 
					# ProtoTorch Releases
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Release 0.2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Includes
 | 
				
			||||||
 | 
					- Fixes in example scripts.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Release 0.1.1-dev0
 | 
					## Release 0.1.1-dev0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Includes
 | 
					### Includes
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										20
									
								
								docs/Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								docs/Makefile
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					# Minimal makefile for Sphinx documentation
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# You can set these variables from the command line, and also
 | 
				
			||||||
 | 
					# from the environment for the first two.
 | 
				
			||||||
 | 
					SPHINXOPTS    ?=
 | 
				
			||||||
 | 
					SPHINXBUILD   ?= python3 -m sphinx
 | 
				
			||||||
 | 
					SOURCEDIR     = source
 | 
				
			||||||
 | 
					BUILDDIR      = build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Put it first so that "make" without argument is like "make help".
 | 
				
			||||||
 | 
					help:
 | 
				
			||||||
 | 
						@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					.PHONY: help Makefile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Catch-all target: route all unknown targets to Sphinx using the new
 | 
				
			||||||
 | 
					# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
 | 
				
			||||||
 | 
					%: Makefile
 | 
				
			||||||
 | 
						@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 | 
				
			||||||
							
								
								
									
										35
									
								
								docs/make.bat
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								docs/make.bat
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					@ECHO OFF
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pushd %~dp0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REM Command file for Sphinx documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if "%SPHINXBUILD%" == "" (
 | 
				
			||||||
 | 
					  set SPHINXBUILD=sphinx-build
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					set SOURCEDIR=source
 | 
				
			||||||
 | 
					set BUILDDIR=build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if "%1" == "" goto help
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%SPHINXBUILD% >NUL 2>NUL
 | 
				
			||||||
 | 
					if errorlevel 9009 (
 | 
				
			||||||
 | 
					  echo.
 | 
				
			||||||
 | 
					  echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
 | 
				
			||||||
 | 
					  echo.installed, then set the SPHINXBUILD environment variable to point
 | 
				
			||||||
 | 
					  echo.to the full path of the 'sphinx-build' executable. Alternatively you
 | 
				
			||||||
 | 
					  echo.may add the Sphinx directory to PATH.
 | 
				
			||||||
 | 
					  echo.
 | 
				
			||||||
 | 
					  echo.If you don't have Sphinx installed, grab it from
 | 
				
			||||||
 | 
					  echo.http://sphinx-doc.org/
 | 
				
			||||||
 | 
					  exit /b 1
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
 | 
				
			||||||
 | 
					goto end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					:help
 | 
				
			||||||
 | 
					%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					:end
 | 
				
			||||||
 | 
					popd
 | 
				
			||||||
							
								
								
									
										4
									
								
								docs/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								docs/requirements.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					torch==1.6.0
 | 
				
			||||||
 | 
					matplotlib==3.1.2
 | 
				
			||||||
 | 
					sphinx_rtd_theme==0.5.0
 | 
				
			||||||
 | 
					sphinxcontrib-katex==0.6.1
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								docs/source/_static/img/horizontal-lockup.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/source/_static/img/horizontal-lockup.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 88 KiB  | 
							
								
								
									
										28
									
								
								docs/source/api.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								docs/source/api.rst
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					.. ProtoFlow API Reference
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ProtoFlow API Reference
 | 
				
			||||||
 | 
					======================================
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Datasets
 | 
				
			||||||
 | 
					--------------------------------------
 | 
				
			||||||
 | 
					.. automodule:: prototorch.datasets
 | 
				
			||||||
 | 
					   :members:
 | 
				
			||||||
 | 
					   :undoc-members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Functions
 | 
				
			||||||
 | 
					--------------------------------------
 | 
				
			||||||
 | 
					.. automodule:: prototorch.functions
 | 
				
			||||||
 | 
					   :members:
 | 
				
			||||||
 | 
					   :undoc-members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Modules
 | 
				
			||||||
 | 
					--------------------------------------
 | 
				
			||||||
 | 
					.. automodule:: prototorch.modules
 | 
				
			||||||
 | 
					   :members:
 | 
				
			||||||
 | 
					   :undoc-members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Utilities
 | 
				
			||||||
 | 
					--------------------------------------
 | 
				
			||||||
 | 
					.. automodule:: prototorch.utils
 | 
				
			||||||
 | 
					   :members:
 | 
				
			||||||
 | 
					   :undoc-members:
 | 
				
			||||||
							
								
								
									
										180
									
								
								docs/source/conf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								docs/source/conf.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,180 @@
 | 
				
			|||||||
 | 
					# Configuration file for the Sphinx documentation builder.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This file only contains a selection of the most common options. For a full
 | 
				
			||||||
 | 
					# list see the documentation:
 | 
				
			||||||
 | 
					# https://www.sphinx-doc.org/en/master/usage/configuration.html
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Path setup --------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# If extensions (or modules to document with autodoc) are in another directory,
 | 
				
			||||||
 | 
					# add these directories to sys.path here. If the directory is relative to the
 | 
				
			||||||
 | 
					# documentation root, use os.path.abspath to make it absolute, like shown here.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					sys.path.insert(0, os.path.abspath("../../"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import sphinx_rtd_theme
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Project information -----------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					project = "ProtoTorch"
 | 
				
			||||||
 | 
					copyright = "2021, Jensun Ravichandran"
 | 
				
			||||||
 | 
					author = "Jensun Ravichandran"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The full version, including alpha/beta/rc tags
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					release = "0.2.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- General configuration ---------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# If your documentation needs a minimal Sphinx version, state it here.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					needs_sphinx = "1.6"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Add any Sphinx extension module names here, as strings. They can be
 | 
				
			||||||
 | 
					# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
 | 
				
			||||||
 | 
					# ones.
 | 
				
			||||||
 | 
					extensions = [
 | 
				
			||||||
 | 
					    "recommonmark",
 | 
				
			||||||
 | 
					    "sphinx.ext.autodoc",
 | 
				
			||||||
 | 
					    "sphinx.ext.autosummary",
 | 
				
			||||||
 | 
					    "sphinx.ext.doctest",
 | 
				
			||||||
 | 
					    "sphinx.ext.intersphinx",
 | 
				
			||||||
 | 
					    "sphinx.ext.todo",
 | 
				
			||||||
 | 
					    "sphinx.ext.coverage",
 | 
				
			||||||
 | 
					    "sphinx.ext.napoleon",
 | 
				
			||||||
 | 
					    "sphinx.ext.viewcode",
 | 
				
			||||||
 | 
					    "sphinx_rtd_theme",
 | 
				
			||||||
 | 
					    "sphinxcontrib.katex",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# katex_prerender = True
 | 
				
			||||||
 | 
					katex_prerender = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					napoleon_use_ivar = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Add any paths that contain templates here, relative to this directory.
 | 
				
			||||||
 | 
					templates_path = ["_templates"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The suffix(es) of source filenames.
 | 
				
			||||||
 | 
					# You can specify multiple suffix as a list of string:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					source_suffix = [".rst", ".md"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The master toctree document.
 | 
				
			||||||
 | 
					master_doc = "index"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# List of patterns, relative to source directory, that match files and
 | 
				
			||||||
 | 
					# directories to ignore when looking for source files.
 | 
				
			||||||
 | 
					# This pattern also affects html_static_path and html_extra_path.
 | 
				
			||||||
 | 
					exclude_patterns = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The name of the Pygments (syntax highlighting) style to use. Choose from:
 | 
				
			||||||
 | 
					# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
 | 
				
			||||||
 | 
					#  "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
 | 
				
			||||||
 | 
					#  "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
 | 
				
			||||||
 | 
					#  "lovelace", "algol", "algol_nu", "arduino", "rainbo w_dash", "abap",
 | 
				
			||||||
 | 
					#  "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
 | 
				
			||||||
 | 
					#  "stata-dark", "inkpot"]
 | 
				
			||||||
 | 
					pygments_style = "monokai"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# If true, `todo` and `todoList` produce output, else they produce nothing.
 | 
				
			||||||
 | 
					todo_include_todos = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Disable docstring inheritance
 | 
				
			||||||
 | 
					autodoc_inherit_docstrings = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for HTML output -------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# The theme to use for HTML and HTML Help pages.  See the documentation for
 | 
				
			||||||
 | 
					# a list of builtin themes.
 | 
				
			||||||
 | 
					# https://sphinx-themes.org/
 | 
				
			||||||
 | 
					html_theme = "sphinx_rtd_theme"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					html_logo = "_static/img/horizontal-lockup.png"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					html_theme_options = {
 | 
				
			||||||
 | 
					    "logo_only": True,
 | 
				
			||||||
 | 
					    "display_version": True,
 | 
				
			||||||
 | 
					    "prev_next_buttons_location": "bottom",
 | 
				
			||||||
 | 
					    "style_external_links": False,
 | 
				
			||||||
 | 
					    "style_nav_header_background": "#ffffff",
 | 
				
			||||||
 | 
					    # Toc options
 | 
				
			||||||
 | 
					    "collapse_navigation": True,
 | 
				
			||||||
 | 
					    "sticky_navigation": True,
 | 
				
			||||||
 | 
					    "navigation_depth": 4,
 | 
				
			||||||
 | 
					    "includehidden": True,
 | 
				
			||||||
 | 
					    "titles_only": False,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Add any paths that contain custom static files (such as style sheets) here,
 | 
				
			||||||
 | 
					# relative to this directory. They are copied after the builtin static files,
 | 
				
			||||||
 | 
					# so a file named "default.css" will overwrite the builtin "default.css".
 | 
				
			||||||
 | 
					html_static_path = ["_static"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					html_css_files = [
 | 
				
			||||||
 | 
					    "https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for HTMLHelp output ------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Output file base name for HTML help builder.
 | 
				
			||||||
 | 
					htmlhelp_basename = "protoflowdoc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for LaTeX output ---------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					latex_elements = {
 | 
				
			||||||
 | 
					    # The paper size ("letterpaper" or "a4paper").
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # "papersize": "letterpaper",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # The font size ("10pt", "11pt" or "12pt").
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # "pointsize": "10pt",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Additional stuff for the LaTeX preamble.
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # "preamble": "",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Latex figure (float) alignment
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # "figure_align": "htbp",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Grouping the document tree into LaTeX files. List of tuples
 | 
				
			||||||
 | 
					# (source start file, target name, title,
 | 
				
			||||||
 | 
					#  author, documentclass [howto, manual, or own class]).
 | 
				
			||||||
 | 
					latex_documents = [
 | 
				
			||||||
 | 
					    (master_doc, "prototorch.tex", "ProtoTorch Documentation",
 | 
				
			||||||
 | 
					     "Jensun Ravichandran", "manual"),
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for manual page output ---------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# One entry per manual page. List of tuples
 | 
				
			||||||
 | 
					# (source start file, name, description, authors, manual section).
 | 
				
			||||||
 | 
					man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for Texinfo output -------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Grouping the document tree into Texinfo files. List of tuples
 | 
				
			||||||
 | 
					# (source start file, target name, title, author,
 | 
				
			||||||
 | 
					#  dir menu entry, description, category)
 | 
				
			||||||
 | 
					texinfo_documents = [
 | 
				
			||||||
 | 
					    (master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
 | 
				
			||||||
 | 
					     "Prototype-based machine learning in PyTorch.",
 | 
				
			||||||
 | 
					     "Miscellaneous"),
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Example configuration for intersphinx: refer to the Python standard library.
 | 
				
			||||||
 | 
					intersphinx_mapping = {
 | 
				
			||||||
 | 
					    "python": ("https://docs.python.org/", None),
 | 
				
			||||||
 | 
					    "numpy": ("https://docs.scipy.org/doc/numpy/", None),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# -- Options for Epub output ----------------------------------------------
 | 
				
			||||||
 | 
					# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					epub_cover = ()
 | 
				
			||||||
 | 
					version = release
 | 
				
			||||||
							
								
								
									
										22
									
								
								docs/source/index.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								docs/source/index.rst
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					.. ProtoTorch documentation master file
 | 
				
			||||||
 | 
					   You can adapt this file completely to your liking, but it should at least
 | 
				
			||||||
 | 
					   contain the root `toctree` directive.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					About ProtoTorch
 | 
				
			||||||
 | 
					================
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					.. toctree::
 | 
				
			||||||
 | 
					   :hidden:
 | 
				
			||||||
 | 
					   :maxdepth: 3
 | 
				
			||||||
 | 
					   :caption: Contents:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   self
 | 
				
			||||||
 | 
					   api
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
 | 
				
			||||||
 | 
					research in prototype-based machine learning algorithms.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Indices
 | 
				
			||||||
 | 
					=======
 | 
				
			||||||
 | 
					* :ref:`genindex`
 | 
				
			||||||
 | 
					* :ref:`modindex`
 | 
				
			||||||
							
								
								
									
										162
									
								
								examples/gtlvq_mnist.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								examples/gtlvq_mnist.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,162 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					ProtoTorch GTLVQ example using MNIST data.
 | 
				
			||||||
 | 
					The GTLVQ is placed as an classification model on
 | 
				
			||||||
 | 
					top of a CNN, considered as featurer extractor.
 | 
				
			||||||
 | 
					Initialization of subpsace and prototypes in
 | 
				
			||||||
 | 
					Siamnese fashion
 | 
				
			||||||
 | 
					For more info about GTLVQ see:
 | 
				
			||||||
 | 
					DOI:10.1109/IJCNN.2016.7727534
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torchvision
 | 
				
			||||||
 | 
					from torchvision import transforms
 | 
				
			||||||
 | 
					from prototorch.modules.losses import GLVQLoss
 | 
				
			||||||
 | 
					from prototorch.functions.helper import calculate_prototype_accuracy
 | 
				
			||||||
 | 
					from prototorch.modules.models import GTLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Parameters and options
 | 
				
			||||||
 | 
					n_epochs = 50
 | 
				
			||||||
 | 
					batch_size_train = 64
 | 
				
			||||||
 | 
					batch_size_test = 1000
 | 
				
			||||||
 | 
					learning_rate = 0.1
 | 
				
			||||||
 | 
					momentum = 0.5
 | 
				
			||||||
 | 
					log_interval = 10
 | 
				
			||||||
 | 
					cuda = "cuda:1"
 | 
				
			||||||
 | 
					random_seed = 1
 | 
				
			||||||
 | 
					device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Configures reproducability
 | 
				
			||||||
 | 
					torch.manual_seed(random_seed)
 | 
				
			||||||
 | 
					np.random.seed(random_seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Prepare and preprocess the data
 | 
				
			||||||
 | 
					train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
 | 
				
			||||||
 | 
					    './files/',
 | 
				
			||||||
 | 
					    train=True,
 | 
				
			||||||
 | 
					    download=True,
 | 
				
			||||||
 | 
					    transform=torchvision.transforms.Compose(
 | 
				
			||||||
 | 
					        [transforms.ToTensor(),
 | 
				
			||||||
 | 
					         transforms.Normalize((0.1307, ), (0.3081, ))])),
 | 
				
			||||||
 | 
					                                           batch_size=batch_size_train,
 | 
				
			||||||
 | 
					                                           shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
 | 
				
			||||||
 | 
					    './files/',
 | 
				
			||||||
 | 
					    train=False,
 | 
				
			||||||
 | 
					    download=True,
 | 
				
			||||||
 | 
					    transform=torchvision.transforms.Compose(
 | 
				
			||||||
 | 
					        [transforms.ToTensor(),
 | 
				
			||||||
 | 
					         transforms.Normalize((0.1307, ), (0.3081, ))])),
 | 
				
			||||||
 | 
					                                          batch_size=batch_size_test,
 | 
				
			||||||
 | 
					                                          shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Define the GLVQ model plus appropriate feature extractor
 | 
				
			||||||
 | 
					class CNNGTLVQ(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        num_classes,
 | 
				
			||||||
 | 
					        subspace_data,
 | 
				
			||||||
 | 
					        prototype_data,
 | 
				
			||||||
 | 
					        tangent_projection_type="local",
 | 
				
			||||||
 | 
					        prototypes_per_class=2,
 | 
				
			||||||
 | 
					        bottleneck_dim=128,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super(CNNGTLVQ, self).__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #Feature Extractor - Simple CNN
 | 
				
			||||||
 | 
					        self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
 | 
				
			||||||
 | 
					                                nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
 | 
				
			||||||
 | 
					                                nn.MaxPool2d(2), nn.Dropout(0.25),
 | 
				
			||||||
 | 
					                                nn.Flatten(), nn.Linear(9216, bottleneck_dim),
 | 
				
			||||||
 | 
					                                nn.Dropout(0.5), nn.LeakyReLU(),
 | 
				
			||||||
 | 
					                                nn.LayerNorm(bottleneck_dim))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Forward pass of subspace and prototype initialization data through feature extractor
 | 
				
			||||||
 | 
					        subspace_data = self.fe(subspace_data)
 | 
				
			||||||
 | 
					        prototype_data[0] = self.fe(prototype_data[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Initialization of GTLVQ
 | 
				
			||||||
 | 
					        self.gtlvq = GTLVQ(num_classes,
 | 
				
			||||||
 | 
					                           subspace_data,
 | 
				
			||||||
 | 
					                           prototype_data,
 | 
				
			||||||
 | 
					                           tangent_projection_type=tangent_projection_type,
 | 
				
			||||||
 | 
					                           feature_dim=bottleneck_dim,
 | 
				
			||||||
 | 
					                           prototypes_per_class=prototypes_per_class)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        # Feature Extraction
 | 
				
			||||||
 | 
					        x = self.fe(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # GTLVQ Forward pass
 | 
				
			||||||
 | 
					        dis = self.gtlvq(x)
 | 
				
			||||||
 | 
					        return dis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Get init data
 | 
				
			||||||
 | 
					subspace_data = torch.cat(
 | 
				
			||||||
 | 
					    [next(iter(train_loader))[0],
 | 
				
			||||||
 | 
					     next(iter(test_loader))[0]])
 | 
				
			||||||
 | 
					prototype_data = next(iter(train_loader))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Build the CNN GTLVQ  model
 | 
				
			||||||
 | 
					model = CNNGTLVQ(10,
 | 
				
			||||||
 | 
					                 subspace_data,
 | 
				
			||||||
 | 
					                 prototype_data,
 | 
				
			||||||
 | 
					                 tangent_projection_type="local",
 | 
				
			||||||
 | 
					                 bottleneck_dim=128).to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Optimize using SGD optimizer from `torch.optim`
 | 
				
			||||||
 | 
					optimizer = torch.optim.Adam([{
 | 
				
			||||||
 | 
					    'params': model.fe.parameters()
 | 
				
			||||||
 | 
					}, {
 | 
				
			||||||
 | 
					    'params': model.gtlvq.parameters()
 | 
				
			||||||
 | 
					}],
 | 
				
			||||||
 | 
					                             lr=learning_rate)
 | 
				
			||||||
 | 
					criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Training loop
 | 
				
			||||||
 | 
					for epoch in range(n_epochs):
 | 
				
			||||||
 | 
					    for batch_idx, (x_train, y_train) in enumerate(train_loader):
 | 
				
			||||||
 | 
					        model.train()
 | 
				
			||||||
 | 
					        x_train, y_train = x_train.to(device), y_train.to(device)
 | 
				
			||||||
 | 
					        optimizer.zero_grad()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distances = model(x_train)
 | 
				
			||||||
 | 
					        plabels = model.gtlvq.cls.prototype_labels.to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Compute loss.
 | 
				
			||||||
 | 
					        loss = criterion([distances, plabels], y_train)
 | 
				
			||||||
 | 
					        loss.backward()
 | 
				
			||||||
 | 
					        optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
 | 
				
			||||||
 | 
					        model.gtlvq.orthogonalize_subspace()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if batch_idx % log_interval == 0:
 | 
				
			||||||
 | 
					            acc = calculate_prototype_accuracy(distances, y_train, plabels)
 | 
				
			||||||
 | 
					            print(
 | 
				
			||||||
 | 
					                f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
 | 
				
			||||||
 | 
					              Train Acc: {acc.item():02.02f}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Test
 | 
				
			||||||
 | 
					    with torch.no_grad():
 | 
				
			||||||
 | 
					        model.eval()
 | 
				
			||||||
 | 
					        correct = 0
 | 
				
			||||||
 | 
					        total = 0
 | 
				
			||||||
 | 
					        for x_test, y_test in test_loader:
 | 
				
			||||||
 | 
					            x_test, y_test = x_test.to(device), y_test.to(device)
 | 
				
			||||||
 | 
					            test_distances = model(torch.tensor(x_test))
 | 
				
			||||||
 | 
					            test_plabels = model.gtlvq.cls.prototype_labels.to(device)
 | 
				
			||||||
 | 
					            i = torch.argmin(test_distances, 1)
 | 
				
			||||||
 | 
					            correct += torch.sum(y_test == test_plabels[i])
 | 
				
			||||||
 | 
					            total += y_test.size(0)
 | 
				
			||||||
 | 
					        print('Accuracy of the network on the test images: %d %%' %
 | 
				
			||||||
 | 
					              (torch.true_divide(correct, total) * 100))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Save the model
 | 
				
			||||||
 | 
					PATH = './glvq_mnist_model.pth'
 | 
				
			||||||
 | 
					torch.save(model.state_dict(), PATH)
 | 
				
			||||||
@@ -3,14 +3,7 @@
 | 
				
			|||||||
# #############################################
 | 
					# #############################################
 | 
				
			||||||
# Core Setup
 | 
					# Core Setup
 | 
				
			||||||
# #############################################
 | 
					# #############################################
 | 
				
			||||||
from importlib.metadata import version, PackageNotFoundError
 | 
					__version__ = "0.2.0"
 | 
				
			||||||
 | 
					 | 
				
			||||||
VERSION_FALLBACK = "uninstalled_version"
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    __version_core__ = version(__name__)
 | 
					 | 
				
			||||||
except PackageNotFoundError:
 | 
					 | 
				
			||||||
    __version_core__ = VERSION_FALLBACK
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch import datasets, functions, modules
 | 
					from prototorch import datasets, functions, modules
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -40,14 +33,14 @@ discovered_plugins = discover_plugins()
 | 
				
			|||||||
locals().update(discovered_plugins)
 | 
					locals().update(discovered_plugins)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Generate combines __version__ and __all__
 | 
					# Generate combines __version__ and __all__
 | 
				
			||||||
__version_plugins__ = "\n".join(
 | 
					version_plugins = "\n".join(
 | 
				
			||||||
    [
 | 
					    [
 | 
				
			||||||
        "- " + name + ": v" + plugin.__version__
 | 
					        "- " + name + ": v" + plugin.__version__
 | 
				
			||||||
        for name, plugin in discovered_plugins.items()
 | 
					        for name, plugin in discovered_plugins.items()
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
if __version_plugins__ != "":
 | 
					if version_plugins != "":
 | 
				
			||||||
    __version_plugins__ = "\nPlugins: \n" + __version_plugins__
 | 
					    version_plugins = "\nPlugins: \n" + version_plugins
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "core: v" + __version_core__ + __version_plugins__
 | 
					version = "core: v" + __version__ + version_plugins
 | 
				
			||||||
__all__ = __all_core__ + list(discovered_plugins.keys())
 | 
					__all__ = __all_core__ + list(discovered_plugins.keys())
 | 
				
			||||||
@@ -1,6 +1,8 @@
 | 
				
			|||||||
"""ProtoTorch distance functions."""
 | 
					"""ProtoTorch distance functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def squared_euclidean_distance(x, y):
 | 
					def squared_euclidean_distance(x, y):
 | 
				
			||||||
@@ -71,5 +73,155 @@ def lomega_distance(x, y, omegas):
 | 
				
			|||||||
    return distances
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
 | 
				
			||||||
 | 
					    r""" Computes an euclidean distanes matrix given two distinct vectors.
 | 
				
			||||||
 | 
					    last dimension must be the vector dimension!
 | 
				
			||||||
 | 
					    compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    x.shape = (number_of_x_vectors, vector_dim)
 | 
				
			||||||
 | 
					    y.shape = (number_of_y_vectors, vector_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    for tensor in [x, y]:
 | 
				
			||||||
 | 
					        if tensor.ndim != 2:
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                'The tensor dimension must be two. You provide: tensor.ndim=' +
 | 
				
			||||||
 | 
					                str(tensor.ndim) + '.')
 | 
				
			||||||
 | 
					    if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
 | 
				
			||||||
 | 
					        raise ValueError(
 | 
				
			||||||
 | 
					            'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
 | 
				
			||||||
 | 
					            + str(tuple(x.shape)[1]) + ' and  tuple(y.shape)(y)[1]=' +
 | 
				
			||||||
 | 
					            str(tuple(y.shape)[1]) + '.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    y = torch.transpose(y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    diss = torch.sum(x**2, axis=1,
 | 
				
			||||||
 | 
					                     keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
 | 
				
			||||||
 | 
					                         y**2, axis=0, keepdims=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not squared:
 | 
				
			||||||
 | 
					        if epsilon == 0:
 | 
				
			||||||
 | 
					            diss = torch.sqrt(diss)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            diss = torch.sqrt(torch.max(diss, epsilon))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return diss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 | 
				
			||||||
 | 
					    r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
 | 
				
			||||||
 | 
					    For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
 | 
				
			||||||
 | 
					    The subspaces is always assumed as transposed and must be orthogonal!
 | 
				
			||||||
 | 
					    For local non sparse signals subspaces must be provided!
 | 
				
			||||||
 | 
					    shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
 | 
				
			||||||
 | 
					    shape(protos): proto_number x dim1 x dim2 x ... x dimN
 | 
				
			||||||
 | 
					    shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN)  x prod(projected_atom_shape)
 | 
				
			||||||
 | 
					    subspace should be orthogonalized
 | 
				
			||||||
 | 
					    Pytorch implementation of Sascha Saralajew's tensorflow code.
 | 
				
			||||||
 | 
					    Translation by Christoph Raab
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
 | 
				
			||||||
 | 
					    proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
 | 
				
			||||||
 | 
					    subspace_int_shape = tuple(subspaces.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # check if the shapes are correct
 | 
				
			||||||
 | 
					    _check_shapes(signal_int_shape, proto_int_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    atom_axes = list(range(3, len(signal_int_shape)))
 | 
				
			||||||
 | 
					    # for sparse signals, we use the memory efficient implementation
 | 
				
			||||||
 | 
					    if signal_int_shape[1] == 1:
 | 
				
			||||||
 | 
					        signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(atom_axes) > 1:
 | 
				
			||||||
 | 
					            protos = torch.reshape(protos, [proto_shape[0], -1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if subspaces.ndim == 2:
 | 
				
			||||||
 | 
					            # clean solution without map if the matrix_scope is global
 | 
				
			||||||
 | 
					            projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
 | 
				
			||||||
 | 
					                subspaces, torch.transpose(subspaces))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            projected_signals = torch.dot(signals, projectors)
 | 
				
			||||||
 | 
					            projected_protos = torch.dot(protos, projectors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = euclidean_distance_matrix(projected_signals,
 | 
				
			||||||
 | 
					                                             projected_protos,
 | 
				
			||||||
 | 
					                                             squared=squared,
 | 
				
			||||||
 | 
					                                             epsilon=epsilon)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = torch.reshape(
 | 
				
			||||||
 | 
					                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return torch.permute(diss, [0, 2, 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # no solution without map possible --> memory efficient but slow!
 | 
				
			||||||
 | 
					            projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
 | 
				
			||||||
 | 
					                subspaces,
 | 
				
			||||||
 | 
					                subspaces)  #K.batch_dot(subspaces, subspaces, [2, 2])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            projected_protos = (protos @ subspaces
 | 
				
			||||||
 | 
					                                ).T  #K.batch_dot(projectors, protos, [1, 1]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            def projected_norm(projector):
 | 
				
			||||||
 | 
					                return torch.sum(torch.dot(signals, projector)**2, axis=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = torch.transpose(map(projected_norm, projectors)) \
 | 
				
			||||||
 | 
					                    - 2 * torch.dot(signals, projected_protos) \
 | 
				
			||||||
 | 
					                    + torch.sum(projected_protos**2, axis=0, keepdims=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not squared:
 | 
				
			||||||
 | 
					                if epsilon == 0:
 | 
				
			||||||
 | 
					                    diss = torch.sqrt(diss)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    diss = torch.sqrt(torch.max(diss, epsilon))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = torch.reshape(
 | 
				
			||||||
 | 
					                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return torch.permute(diss, [0, 2, 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        signals = signals.permute([0, 2, 1] + atom_axes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        diff = signals - protos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # global tangent space
 | 
				
			||||||
 | 
					        if subspaces.ndim == 2:
 | 
				
			||||||
 | 
					            #Scope Projectors
 | 
				
			||||||
 | 
					            projectors = subspaces  #
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            #Scope: Tangentspace Projections
 | 
				
			||||||
 | 
					            diff = torch.reshape(
 | 
				
			||||||
 | 
					                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
 | 
				
			||||||
 | 
					            projected_diff = diff @ projectors
 | 
				
			||||||
 | 
					            projected_diff = torch.reshape(
 | 
				
			||||||
 | 
					                projected_diff,
 | 
				
			||||||
 | 
					                (signal_shape[0], signal_shape[2], signal_shape[1]) +
 | 
				
			||||||
 | 
					                signal_shape[3:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = torch.norm(projected_diff, 2, dim=-1)
 | 
				
			||||||
 | 
					            return diss.permute([0, 2, 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # local tangent spaces
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Scope: Calculate Projectors
 | 
				
			||||||
 | 
					            projectors = subspaces
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Scope: Tangentspace Projections
 | 
				
			||||||
 | 
					            diff = torch.reshape(
 | 
				
			||||||
 | 
					                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
 | 
				
			||||||
 | 
					            diff = diff.permute([1, 0, 2])
 | 
				
			||||||
 | 
					            projected_diff = torch.bmm(diff, projectors)
 | 
				
			||||||
 | 
					            projected_diff = torch.reshape(
 | 
				
			||||||
 | 
					                projected_diff,
 | 
				
			||||||
 | 
					                (signal_shape[1], signal_shape[0], signal_shape[2]) +
 | 
				
			||||||
 | 
					                signal_shape[3:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            diss = torch.norm(projected_diff, 2, dim=-1)
 | 
				
			||||||
 | 
					            return diss.permute([1, 0, 2]).squeeze(-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Aliases
 | 
					# Aliases
 | 
				
			||||||
sed = squared_euclidean_distance
 | 
					sed = squared_euclidean_distance
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										89
									
								
								prototorch/functions/helper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								prototorch/functions/helper.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
				
			|||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def calculate_prototype_accuracy(y_pred, y_true, plabels):
 | 
				
			||||||
 | 
					    """Computes the accuracy of a prototype based model.
 | 
				
			||||||
 | 
					    via Winner-Takes-All rule.
 | 
				
			||||||
 | 
					    Requirement:
 | 
				
			||||||
 | 
					    y_pred.shape == y_true.shape
 | 
				
			||||||
 | 
					    unique(y_pred) in plabels
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    with torch.no_grad():
 | 
				
			||||||
 | 
					        idx = torch.argmin(y_pred, axis=1)
 | 
				
			||||||
 | 
					        return torch.true_divide(torch.sum(y_true == plabels[idx]),
 | 
				
			||||||
 | 
					                                 len(y_pred)) * 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def predict_label(y_pred, plabels):
 | 
				
			||||||
 | 
					    r""" Predicts labels given a prediction of a prototype based model.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    with torch.no_grad():
 | 
				
			||||||
 | 
					        return plabels[torch.argmin(y_pred, 1)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mixed_shape(inputs):
 | 
				
			||||||
 | 
					    if not torch.is_tensor(inputs):
 | 
				
			||||||
 | 
					        raise ValueError('Input must be a tensor.')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        int_shape = list(inputs.shape)
 | 
				
			||||||
 | 
					        # sometimes int_shape returns mixed integer types
 | 
				
			||||||
 | 
					        int_shape = [int(i) if i is not None else i for i in int_shape]
 | 
				
			||||||
 | 
					        tensor_shape = inputs.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, s in enumerate(int_shape):
 | 
				
			||||||
 | 
					            if s is None:
 | 
				
			||||||
 | 
					                int_shape[i] = tensor_shape[i]
 | 
				
			||||||
 | 
					        return tuple(int_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def equal_int_shape(shape_1, shape_2):
 | 
				
			||||||
 | 
					    if not isinstance(shape_1,
 | 
				
			||||||
 | 
					                      (tuple, list)) or not isinstance(shape_2, (tuple, list)):
 | 
				
			||||||
 | 
					        raise ValueError('Input shapes must list or tuple.')
 | 
				
			||||||
 | 
					    for shape in [shape_1, shape_2]:
 | 
				
			||||||
 | 
					        if not all([isinstance(x, int) or x is None for x in shape]):
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                'Input shapes must be list or tuple of int and None values.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if len(shape_1) != len(shape_2):
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        for axis, value in enumerate(shape_1):
 | 
				
			||||||
 | 
					            if value is not None and shape_2[axis] not in {value, None}:
 | 
				
			||||||
 | 
					                return False
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _check_shapes(signal_int_shape, proto_int_shape):
 | 
				
			||||||
 | 
					    if len(signal_int_shape) < 4:
 | 
				
			||||||
 | 
					        raise ValueError(
 | 
				
			||||||
 | 
					            "The number of signal dimensions must be >=4. You provide: " +
 | 
				
			||||||
 | 
					            str(len(signal_int_shape)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if len(proto_int_shape) < 2:
 | 
				
			||||||
 | 
					        raise ValueError(
 | 
				
			||||||
 | 
					            "The number of proto dimensions must be >=2. You provide: " +
 | 
				
			||||||
 | 
					            str(len(proto_int_shape)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
 | 
				
			||||||
 | 
					        raise ValueError(
 | 
				
			||||||
 | 
					            "The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
 | 
				
			||||||
 | 
					            + str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
 | 
				
			||||||
 | 
					            str(proto_int_shape[1:]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # not a sparse signal
 | 
				
			||||||
 | 
					    if signal_int_shape[1] != 1:
 | 
				
			||||||
 | 
					        if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                "If the signal is not sparse, the number of prototypes must be equal in signals and "
 | 
				
			||||||
 | 
					                "protos. You provide: " + str(signal_int_shape[1]) + " != " +
 | 
				
			||||||
 | 
					                str(proto_int_shape[0]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _int_and_mixed_shape(tensor):
 | 
				
			||||||
 | 
					    shape = mixed_shape(tensor)
 | 
				
			||||||
 | 
					    int_shape = tuple([i if isinstance(i, int) else None for i in shape])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return shape, int_shape
 | 
				
			||||||
							
								
								
									
										37
									
								
								prototorch/functions/normalization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								prototorch/functions/normalization.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from __future__ import division
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def orthogonalization(tensors):
 | 
				
			||||||
 | 
					    r""" Orthogonalization of a given tensor via polar decomposition.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    u, _, v = torch.svd(tensors, compute_uv=True)
 | 
				
			||||||
 | 
					    u_shape = tuple(list(u.shape))
 | 
				
			||||||
 | 
					    v_shape = tuple(list(v.shape))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # reshape to (num x N x M)
 | 
				
			||||||
 | 
					    u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
 | 
				
			||||||
 | 
					    v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    out = u @ v.permute([0, 2, 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def trace_normalization(tensors):
 | 
				
			||||||
 | 
					    r""" Trace normalization
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    epsilon = torch.tensor([1e-10], dtype=torch.float64)
 | 
				
			||||||
 | 
					    # Scope trace_normalization
 | 
				
			||||||
 | 
					    constant = torch.trace(tensors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if epsilon != 0:
 | 
				
			||||||
 | 
					        constant = torch.max(constant, epsilon)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return tensors / constant
 | 
				
			||||||
							
								
								
									
										190
									
								
								prototorch/modules/models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								prototorch/modules/models.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,190 @@
 | 
				
			|||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.modules.prototypes import Prototypes1D
 | 
				
			||||||
 | 
					from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
 | 
				
			||||||
 | 
					from prototorch.functions.normalization import orthogonalization
 | 
				
			||||||
 | 
					from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GTLVQ(nn.Module):
 | 
				
			||||||
 | 
					    r""" Generalized Tangent Learning Vector Quantization
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    num_classes: int
 | 
				
			||||||
 | 
					        Number of classes of the given classification problem.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
 | 
				
			||||||
 | 
					        Subspace data for the point approximation, required
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
 | 
				
			||||||
 | 
					        prototype data for initalization of the prototypes used in GTLVQ.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    subspace_size: int (default=256,optional)
 | 
				
			||||||
 | 
					        Subspace dimension of the Projectors. Currently only supported
 | 
				
			||||||
 | 
					        with tagnent_projection_type=global.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tangent_projection_type: string
 | 
				
			||||||
 | 
					        Specifies the tangent projection type
 | 
				
			||||||
 | 
					        options:    local
 | 
				
			||||||
 | 
					                    local_proj
 | 
				
			||||||
 | 
					                    global
 | 
				
			||||||
 | 
					        local: computes the tangent distances without emphasizing projected
 | 
				
			||||||
 | 
					        data. Only distances are available
 | 
				
			||||||
 | 
					        local_proj: computs tangent distances and returns the projected data
 | 
				
			||||||
 | 
					        for further use. Be careful: data is repeated by number of prototypes
 | 
				
			||||||
 | 
					        global: Number of subspaces is set to one and every prototypes
 | 
				
			||||||
 | 
					        uses the same.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    prototypes_per_class: int (default=2,optional)
 | 
				
			||||||
 | 
					    Number of prototypes per class
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    feature_dim: int (default=256)
 | 
				
			||||||
 | 
					    Dimensionality of the feature space specified as integer.
 | 
				
			||||||
 | 
					    Prototype dimension.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Notes
 | 
				
			||||||
 | 
					    -----
 | 
				
			||||||
 | 
					    The GTLVQ [1] is a prototype-based classification learning model. The
 | 
				
			||||||
 | 
					    GTLVQ uses the Tangent-Distances for a local point approximation
 | 
				
			||||||
 | 
					    of an assumed data manifold via prototypial representations.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The GTLVQ requires subspace projectors for transforming the data
 | 
				
			||||||
 | 
					    and prototypes into the affine subspace. Every prototype is
 | 
				
			||||||
 | 
					    equipped with a specific subpspace and represents a point
 | 
				
			||||||
 | 
					    approximation of the assumed manifold.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    In practice prototypes and data are projected on this manifold
 | 
				
			||||||
 | 
					    and pairwise euclidean distance computes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    References
 | 
				
			||||||
 | 
					    ----------
 | 
				
			||||||
 | 
					    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
 | 
				
			||||||
 | 
					    in classification based on manifolc. models and its relation
 | 
				
			||||||
 | 
					    to tangent metric learning. In: 2017 International Joint
 | 
				
			||||||
 | 
					    Conference on Neural Networks (IJCNN).
 | 
				
			||||||
 | 
					    Bd. 2017-May : IEEE, 2017, S. 1756–1765
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        num_classes,
 | 
				
			||||||
 | 
					        subspace_data=None,
 | 
				
			||||||
 | 
					        prototype_data=None,
 | 
				
			||||||
 | 
					        subspace_size=256,
 | 
				
			||||||
 | 
					        tangent_projection_type='local',
 | 
				
			||||||
 | 
					        prototypes_per_class=2,
 | 
				
			||||||
 | 
					        feature_dim=256,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super(GTLVQ, self).__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.num_protos = num_classes * prototypes_per_class
 | 
				
			||||||
 | 
					        self.subspace_size = feature_dim if subspace_size is None else subspace_size
 | 
				
			||||||
 | 
					        self.feature_dim = feature_dim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if subspace_data is None:
 | 
				
			||||||
 | 
					            raise ValueError('Init Data must be specified!')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.tpt = tangent_projection_type
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            if self.tpt == 'local' or self.tpt == 'local_proj':
 | 
				
			||||||
 | 
					                self.init_local_subspace(subspace_data)
 | 
				
			||||||
 | 
					            elif self.tpt == 'global':
 | 
				
			||||||
 | 
					                self.init_gobal_subspace(subspace_data, subspace_size)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.subspaces = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Hypothesis-Margin-Classifier
 | 
				
			||||||
 | 
					        self.cls = Prototypes1D(input_dim=feature_dim,
 | 
				
			||||||
 | 
					                                prototypes_per_class=prototypes_per_class,
 | 
				
			||||||
 | 
					                                nclasses=num_classes,
 | 
				
			||||||
 | 
					                                prototype_initializer='stratified_mean',
 | 
				
			||||||
 | 
					                                data=prototype_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        # Tangent Projection
 | 
				
			||||||
 | 
					        if self.tpt == 'local_proj':
 | 
				
			||||||
 | 
					            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
 | 
				
			||||||
 | 
					                                                         1).unsqueeze(2)
 | 
				
			||||||
 | 
					            dis, proj_x = self.local_tangent_projection(x_conform)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
 | 
				
			||||||
 | 
					                                    self.feature_dim)
 | 
				
			||||||
 | 
					            return proj_x, dis
 | 
				
			||||||
 | 
					        elif self.tpt == "local":
 | 
				
			||||||
 | 
					            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
 | 
				
			||||||
 | 
					                                                         1).unsqueeze(2)
 | 
				
			||||||
 | 
					            dis = tangent_distance(x_conform, self.cls.prototypes,
 | 
				
			||||||
 | 
					                                   self.subspaces)
 | 
				
			||||||
 | 
					        elif self.tpt == "gloabl":
 | 
				
			||||||
 | 
					            dis = self.global_tangent_distances(x)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            dis = (x @ self.cls.prototypes.T) / (
 | 
				
			||||||
 | 
					                torch.norm(x, dim=1, keepdim=True) @ torch.norm(
 | 
				
			||||||
 | 
					                    self.cls.prototypes, dim=1, keepdim=True).T)
 | 
				
			||||||
 | 
					        return dis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_gobal_subspace(self, data, num_subspaces):
 | 
				
			||||||
 | 
					        _, _, v = torch.svd(data)
 | 
				
			||||||
 | 
					        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
 | 
				
			||||||
 | 
					        subspaces = subspace[:, :num_subspaces]
 | 
				
			||||||
 | 
					        self.subspaces = torch.nn.Parameter(
 | 
				
			||||||
 | 
					            subspaces).clone().detach().requires_grad_(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_local_subspace(self, data):
 | 
				
			||||||
 | 
					        _, _, v = torch.svd(data)
 | 
				
			||||||
 | 
					        inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
 | 
				
			||||||
 | 
					        subspaces = inital_projector.unsqueeze(0).repeat_interleave(
 | 
				
			||||||
 | 
					            self.num_protos, 0)
 | 
				
			||||||
 | 
					        self.subspaces = torch.nn.Parameter(
 | 
				
			||||||
 | 
					            subspaces).clone().detach().requires_grad_(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def global_tangent_distances(self, x):
 | 
				
			||||||
 | 
					        # Tangent Projection
 | 
				
			||||||
 | 
					        x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
 | 
				
			||||||
 | 
					        # Euclidean Distance
 | 
				
			||||||
 | 
					        return euclidean_distance_matrix(x, projected_prototypes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def local_tangent_projection(self,
 | 
				
			||||||
 | 
					                                 signals):
 | 
				
			||||||
 | 
					        # Note: subspaces is always assumed as transposed and must be orthogonal!
 | 
				
			||||||
 | 
					        # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
 | 
				
			||||||
 | 
					        # shape(protos): proto_number x dim1 x dim2 x ... x dimN
 | 
				
			||||||
 | 
					        # shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN)  x prod(projected_atom_shape)
 | 
				
			||||||
 | 
					        # subspace should be orthogonalized
 | 
				
			||||||
 | 
					        # Origin Source Code
 | 
				
			||||||
 | 
					        # Origin Author:
 | 
				
			||||||
 | 
					        protos = self.cls.prototypes
 | 
				
			||||||
 | 
					        subspaces = self.subspaces
 | 
				
			||||||
 | 
					        signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
 | 
				
			||||||
 | 
					        _, proto_int_shape = _int_and_mixed_shape(protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # check if the shapes are correct
 | 
				
			||||||
 | 
					        _check_shapes(signal_int_shape, proto_int_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Tangent Data Projections
 | 
				
			||||||
 | 
					        projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
 | 
				
			||||||
 | 
					        data = signals.squeeze(2).permute([1, 0, 2])
 | 
				
			||||||
 | 
					        projected_data = torch.bmm(data, subspaces)
 | 
				
			||||||
 | 
					        projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
 | 
				
			||||||
 | 
					        diff = projected_data - projected_protos
 | 
				
			||||||
 | 
					        projected_diff = torch.reshape(
 | 
				
			||||||
 | 
					            diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
 | 
				
			||||||
 | 
					            signal_shape[3:])
 | 
				
			||||||
 | 
					        diss = torch.norm(projected_diff, 2, dim=-1)
 | 
				
			||||||
 | 
					        return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_parameters(self):
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "params": self.cls.prototypes,
 | 
				
			||||||
 | 
					        }, {
 | 
				
			||||||
 | 
					            "params": self.subspaces
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def orthogonalize_subspace(self):
 | 
				
			||||||
 | 
					        if self.subspaces is not None:
 | 
				
			||||||
 | 
					            with torch.no_grad():
 | 
				
			||||||
 | 
					                ortho_subpsaces = orthogonalization(
 | 
				
			||||||
 | 
					                    self.subspaces
 | 
				
			||||||
 | 
					                ) if self.tpt == 'global' else torch.nn.init.orthogonal_(
 | 
				
			||||||
 | 
					                    self.subspaces)
 | 
				
			||||||
 | 
					                self.subspaces.copy_(ortho_subpsaces)
 | 
				
			||||||
@@ -14,55 +14,24 @@ class _Prototypes(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def _validate_prototype_distribution(self):
 | 
					    def _validate_prototype_distribution(self):
 | 
				
			||||||
        if 0 in self.prototype_distribution:
 | 
					        if 0 in self.prototype_distribution:
 | 
				
			||||||
            warnings.warn('Are you sure about the `0` in '
 | 
					            warnings.warn("Are you sure about the `0` in "
 | 
				
			||||||
                          '`prototype_distribution`?')
 | 
					                          "`prototype_distribution`?")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
 | 
					        return f"prototypes.shape: {tuple(self.prototypes.shape)}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self):
 | 
					    def forward(self):
 | 
				
			||||||
        return self.prototypes, self.prototype_labels
 | 
					        return self.prototypes, self.prototype_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Prototypes1D(_Prototypes):
 | 
					class Prototypes1D(_Prototypes):
 | 
				
			||||||
    r"""Create a learnable set of one-dimensional prototypes.
 | 
					    """Create a learnable set of one-dimensional prototypes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    TODO Complete this doc-string
 | 
					    TODO Complete this doc-string.
 | 
				
			||||||
 | 
					 | 
				
			||||||
    Kwargs:
 | 
					 | 
				
			||||||
        prototypes_per_class: number of prototypes to use per class.
 | 
					 | 
				
			||||||
            Default: ``1``
 | 
					 | 
				
			||||||
        prototype_initializer: prototype initializer.
 | 
					 | 
				
			||||||
            Default: ``'ones'``
 | 
					 | 
				
			||||||
        prototype_distribution: prototype distribution vector.
 | 
					 | 
				
			||||||
            Default: ``None``
 | 
					 | 
				
			||||||
        input_dim: dimension of the incoming data.
 | 
					 | 
				
			||||||
        nclasses: number of classes.
 | 
					 | 
				
			||||||
        data: If set to ``None``, data-dependent initializers will be ignored.
 | 
					 | 
				
			||||||
            Default: ``None``
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Shape:
 | 
					 | 
				
			||||||
        - Input: :math:`(N, H_{in})`
 | 
					 | 
				
			||||||
            where :math:`H_{in} = \text{input_dim}`.
 | 
					 | 
				
			||||||
        - Output: :math:`(N, H_{out})`
 | 
					 | 
				
			||||||
            where :math:`H_{out} = \text{total_prototypes}`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Attributes:
 | 
					 | 
				
			||||||
        prototypes: the learnable weights of the module of shape
 | 
					 | 
				
			||||||
            :math:`(\text{total_prototypes}, \text{prototype_dimension})`.
 | 
					 | 
				
			||||||
        prototype_labels: the non-learnable labels of the prototypes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Examples:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        >>> p = Prototypes1D(input_dim=20, nclasses=10)
 | 
					 | 
				
			||||||
        >>> input = torch.randn(128, 20)
 | 
					 | 
				
			||||||
        >>> output = m(input)
 | 
					 | 
				
			||||||
        >>> print(output.size())
 | 
					 | 
				
			||||||
        torch.Size([20, 10])
 | 
					 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 prototypes_per_class=1,
 | 
					                 prototypes_per_class=1,
 | 
				
			||||||
                 prototype_initializer='ones',
 | 
					                 prototype_initializer="ones",
 | 
				
			||||||
                 prototype_distribution=None,
 | 
					                 prototype_distribution=None,
 | 
				
			||||||
                 data=None,
 | 
					                 data=None,
 | 
				
			||||||
                 dtype=torch.float32,
 | 
					                 dtype=torch.float32,
 | 
				
			||||||
@@ -75,25 +44,25 @@ class Prototypes1D(_Prototypes):
 | 
				
			|||||||
                prototype_distribution = prototype_distribution.tolist()
 | 
					                prototype_distribution = prototype_distribution.tolist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if data is None:
 | 
					        if data is None:
 | 
				
			||||||
            if 'input_dim' not in kwargs:
 | 
					            if "input_dim" not in kwargs:
 | 
				
			||||||
                raise NameError('`input_dim` required if '
 | 
					                raise NameError("`input_dim` required if "
 | 
				
			||||||
                                'no `data` is provided.')
 | 
					                                "no `data` is provided.")
 | 
				
			||||||
            if prototype_distribution:
 | 
					            if prototype_distribution:
 | 
				
			||||||
                kwargs_nclasses = sum(prototype_distribution)
 | 
					                kwargs_nclasses = sum(prototype_distribution)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                if 'nclasses' not in kwargs:
 | 
					                if "nclasses" not in kwargs:
 | 
				
			||||||
                    raise NameError('`prototype_distribution` required if '
 | 
					                    raise NameError("`prototype_distribution` required if "
 | 
				
			||||||
                                    'both `data` and `nclasses` are not '
 | 
					                                    "both `data` and `nclasses` are not "
 | 
				
			||||||
                                    'provided.')
 | 
					                                    "provided.")
 | 
				
			||||||
                kwargs_nclasses = kwargs.pop('nclasses')
 | 
					                kwargs_nclasses = kwargs.pop("nclasses")
 | 
				
			||||||
            input_dim = kwargs.pop('input_dim')
 | 
					            input_dim = kwargs.pop("input_dim")
 | 
				
			||||||
            if prototype_initializer in [
 | 
					            if prototype_initializer in [
 | 
				
			||||||
                    'stratified_mean', 'stratified_random'
 | 
					                    "stratified_mean", "stratified_random"
 | 
				
			||||||
            ]:
 | 
					            ]:
 | 
				
			||||||
                warnings.warn(
 | 
					                warnings.warn(
 | 
				
			||||||
                    f'`prototype_initializer`: `{prototype_initializer}` '
 | 
					                    f"`prototype_initializer`: `{prototype_initializer}` "
 | 
				
			||||||
                    'requires `data`, but `data` is not provided. '
 | 
					                    "requires `data`, but `data` is not provided. "
 | 
				
			||||||
                    'Using randomly generated data instead.')
 | 
					                    "Using randomly generated data instead.")
 | 
				
			||||||
            x_train = torch.rand(kwargs_nclasses, input_dim)
 | 
					            x_train = torch.rand(kwargs_nclasses, input_dim)
 | 
				
			||||||
            y_train = torch.arange(kwargs_nclasses)
 | 
					            y_train = torch.arange(kwargs_nclasses)
 | 
				
			||||||
            if one_hot_labels:
 | 
					            if one_hot_labels:
 | 
				
			||||||
@@ -106,39 +75,39 @@ class Prototypes1D(_Prototypes):
 | 
				
			|||||||
        nclasses = torch.unique(y_train, dim=-1).shape[-1]
 | 
					        nclasses = torch.unique(y_train, dim=-1).shape[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if nclasses == 1:
 | 
					        if nclasses == 1:
 | 
				
			||||||
            warnings.warn('Are you sure about having one class only?')
 | 
					            warnings.warn("Are you sure about having one class only?")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if x_train.ndim != 2:
 | 
					        if x_train.ndim != 2:
 | 
				
			||||||
            raise ValueError('`data[0].ndim != 2`.')
 | 
					            raise ValueError("`data[0].ndim != 2`.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if y_train.ndim == 2:
 | 
					        if y_train.ndim == 2:
 | 
				
			||||||
            if y_train.shape[1] == 1 and one_hot_labels:
 | 
					            if y_train.shape[1] == 1 and one_hot_labels:
 | 
				
			||||||
                raise ValueError('`one_hot_labels` is set to `True` '
 | 
					                raise ValueError("`one_hot_labels` is set to `True` "
 | 
				
			||||||
                                 'but target labels are not one-hot-encoded.')
 | 
					                                 "but target labels are not one-hot-encoded.")
 | 
				
			||||||
            if y_train.shape[1] != 1 and not one_hot_labels:
 | 
					            if y_train.shape[1] != 1 and not one_hot_labels:
 | 
				
			||||||
                raise ValueError('`one_hot_labels` is set to `False` '
 | 
					                raise ValueError("`one_hot_labels` is set to `False` "
 | 
				
			||||||
                                 'but target labels in `data` '
 | 
					                                 "but target labels in `data` "
 | 
				
			||||||
                                 'are one-hot-encoded.')
 | 
					                                 "are one-hot-encoded.")
 | 
				
			||||||
        if y_train.ndim == 1 and one_hot_labels:
 | 
					        if y_train.ndim == 1 and one_hot_labels:
 | 
				
			||||||
            raise ValueError('`one_hot_labels` is set to `True` '
 | 
					            raise ValueError("`one_hot_labels` is set to `True` "
 | 
				
			||||||
                             'but target labels are not one-hot-encoded.')
 | 
					                             "but target labels are not one-hot-encoded.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Verify input dimension if `input_dim` is provided
 | 
					        # Verify input dimension if `input_dim` is provided
 | 
				
			||||||
        if 'input_dim' in kwargs:
 | 
					        if "input_dim" in kwargs:
 | 
				
			||||||
            input_dim = kwargs.pop('input_dim')
 | 
					            input_dim = kwargs.pop("input_dim")
 | 
				
			||||||
            if input_dim != x_train.shape[1]:
 | 
					            if input_dim != x_train.shape[1]:
 | 
				
			||||||
                raise ValueError(f'Provided `input_dim`={input_dim} does '
 | 
					                raise ValueError(f"Provided `input_dim`={input_dim} does "
 | 
				
			||||||
                                 'not match data dimension '
 | 
					                                 "not match data dimension "
 | 
				
			||||||
                                 f'`data[0].shape[1]`={x_train.shape[1]}')
 | 
					                                 f"`data[0].shape[1]`={x_train.shape[1]}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Verify the number of classes if `nclasses` is provided
 | 
					        # Verify the number of classes if `nclasses` is provided
 | 
				
			||||||
        if 'nclasses' in kwargs:
 | 
					        if "nclasses" in kwargs:
 | 
				
			||||||
            kwargs_nclasses = kwargs.pop('nclasses')
 | 
					            kwargs_nclasses = kwargs.pop("nclasses")
 | 
				
			||||||
            if kwargs_nclasses != nclasses:
 | 
					            if kwargs_nclasses != nclasses:
 | 
				
			||||||
                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
 | 
					                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
 | 
				
			||||||
                                 'not match data labels '
 | 
					                                 "not match data labels "
 | 
				
			||||||
                                 '`torch.unique(data[1]).shape[0]`'
 | 
					                                 "`torch.unique(data[1]).shape[0]`"
 | 
				
			||||||
                                 f'={nclasses}')
 | 
					                                 f"={nclasses}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        super().__init__(**kwargs)
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										12
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								setup.py
									
									
									
									
									
								
							@@ -12,13 +12,6 @@ ProtoTorch Core Package
 | 
				
			|||||||
from setuptools import setup
 | 
					from setuptools import setup
 | 
				
			||||||
from setuptools import find_packages
 | 
					from setuptools import find_packages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pkg_resources import safe_name
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import ast
 | 
					 | 
				
			||||||
import importlib.util
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
PKG_DIR = "prototorch"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
 | 
					PROJECT_URL = "https://github.com/si-cim/prototorch"
 | 
				
			||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
 | 
					DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -49,8 +42,8 @@ TESTS = ["pytest"]
 | 
				
			|||||||
ALL = DOCS + DATASETS + EXAMPLES + TESTS
 | 
					ALL = DOCS + DATASETS + EXAMPLES + TESTS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
setup(
 | 
					setup(
 | 
				
			||||||
    name=safe_name(PKG_DIR),
 | 
					    name="prototorch",
 | 
				
			||||||
    use_scm_version=True,
 | 
					    version="0.2.0",
 | 
				
			||||||
    description="Highly extensible, GPU-supported "
 | 
					    description="Highly extensible, GPU-supported "
 | 
				
			||||||
    "Learning Vector Quantization (LVQ) toolbox "
 | 
					    "Learning Vector Quantization (LVQ) toolbox "
 | 
				
			||||||
    "built using PyTorch and its nn API.",
 | 
					    "built using PyTorch and its nn API.",
 | 
				
			||||||
@@ -62,7 +55,6 @@ setup(
 | 
				
			|||||||
    download_url=DOWNLOAD_URL,
 | 
					    download_url=DOWNLOAD_URL,
 | 
				
			||||||
    license="MIT",
 | 
					    license="MIT",
 | 
				
			||||||
    install_requires=INSTALL_REQUIRES,
 | 
					    install_requires=INSTALL_REQUIRES,
 | 
				
			||||||
    setup_requires=["setuptools_scm"],
 | 
					 | 
				
			||||||
    extras_require={
 | 
					    extras_require={
 | 
				
			||||||
        "docs": DOCS,
 | 
					        "docs": DOCS,
 | 
				
			||||||
        "datasets": DATASETS,
 | 
					        "datasets": DATASETS,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user