Merge version 0.2.0 into feature/plugin-architecture.
commit c42df6e203
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.1-rc0
+current_version = 0.2.0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
@@ -19,3 +19,4 @@ values =
 
 [bumpversion:file:./prototorch/__init__.py]
 
+[bumpversion:file:./docs/source/conf.py]
27  .readthedocs.yml  Normal file
@@ -0,0 +1,27 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/source/conf.py
  fail_on_warning: true

# Build documentation with MkDocs
# mkdocs:
#   configuration: mkdocs.yml

# Optionally build your docs in additional formats such as PDF and ePub
formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.8
  install:
    - method: pip
      path: .
      extra_requirements:
        - all
18  README.md
@@ -45,22 +45,8 @@ pip install -e .[all]
 
 ## Documentation
 
-The documentation is available at <https://prototorch.readthedocs.io/en/latest/>
-
-## Usage
-
-### For researchers
-ProtoTorch is modular. It is very easy to use the modular pieces provided by
-ProtoTorch, like the layers, losses, callbacks and metrics to build your own
-prototype-based(instance-based) models. These pieces blend-in seamlessly with
-Keras allowing you to mix and match the modules from ProtoFlow with other
-modules in `torch.nn`.
-
-### For engineers
-ProtoTorch comes prepackaged with many popular Learning Vector Quantization
-(LVQ)-like algorithms in a convenient API. If you would simply like to be able
-to use those algorithms to train large ML models on a GPU, ProtoTorch lets you
-do this without requiring a black-belt in high-performance Tensor computing.
+The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
+that link not work try <https://prototorch.readthedocs.io/en/latest/>.
 
 ## Bibtex
 
@@ -1,5 +1,10 @@
 # ProtoTorch Releases
 
+## Release 0.2.0
+
+### Includes
+- Fixes in example scripts.
+
 ## Release 0.1.1-dev0
 
 ### Includes
20  docs/Makefile  Normal file
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS  ?=
SPHINXBUILD ?= python3 -m sphinx
SOURCEDIR   = source
BUILDDIR    = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
35  docs/make.bat  Normal file
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
4  docs/requirements.txt  Normal file
@@ -0,0 +1,4 @@
torch==1.6.0
matplotlib==3.1.2
sphinx_rtd_theme==0.5.0
sphinxcontrib-katex==0.6.1
BIN  docs/source/_static/img/horizontal-lockup.png  Normal file
Binary file not shown. After: Width | Height | Size: 88 KiB
28  docs/source/api.rst  Normal file
@@ -0,0 +1,28 @@
.. ProtoTorch API Reference

ProtoTorch API Reference
======================================

Datasets
--------------------------------------
.. automodule:: prototorch.datasets
   :members:
   :undoc-members:

Functions
--------------------------------------
.. automodule:: prototorch.functions
   :members:
   :undoc-members:

Modules
--------------------------------------
.. automodule:: prototorch.modules
   :members:
   :undoc-members:

Utilities
--------------------------------------
.. automodule:: prototorch.utils
   :members:
   :undoc-members:
180  docs/source/conf.py  Normal file
@@ -0,0 +1,180 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath("../../"))

import sphinx_rtd_theme

# -- Project information -----------------------------------------------------

project = "ProtoTorch"
copyright = "2021, Jensun Ravichandran"
author = "Jensun Ravichandran"

# The full version, including alpha/beta/rc tags
#
release = "0.2.0"

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
# ones.
extensions = [
    "recommonmark",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx_rtd_theme",
    "sphinxcontrib.katex",
]

# katex_prerender = True
katex_prerender = False

napoleon_use_ivar = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]

# The master toctree document.
master_doc = "index"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
# "lovelace", "algol", "algol_nu", "arduino", "rainbow_dash", "abap",
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
# "stata-dark", "inkpot"]
pygments_style = "monokai"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True

# Disable docstring inheritance
autodoc_inherit_docstrings = False

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme"

html_logo = "_static/img/horizontal-lockup.png"

html_theme_options = {
    "logo_only": True,
    "display_version": True,
    "prev_next_buttons_location": "bottom",
    "style_external_links": False,
    "style_nav_header_background": "#ffffff",
    # Toc options
    "collapse_navigation": True,
    "sticky_navigation": True,
    "navigation_depth": 4,
    "includehidden": True,
    "titles_only": False,
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

html_css_files = [
    "https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
]

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "protoflowdoc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ("letterpaper" or "a4paper").
    #
    # "papersize": "letterpaper",

    # The font size ("10pt", "11pt" or "12pt").
    #
    # "pointsize": "10pt",

    # Additional stuff for the LaTeX preamble.
    #
    # "preamble": "",

    # Latex figure (float) alignment
    #
    # "figure_align": "htbp",
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "prototorch.tex", "ProtoTorch Documentation",
     "Jensun Ravichandran", "manual"),
]

# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]

# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
     "Prototype-based machine learning in PyTorch.",
     "Miscellaneous"),
]

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
    "python": ("https://docs.python.org/", None),
    "numpy": ("https://docs.scipy.org/doc/numpy/", None),
}

# -- Options for Epub output ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output

epub_cover = ()
version = release
22  docs/source/index.rst  Normal file
@@ -0,0 +1,22 @@
.. ProtoTorch documentation master file
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

About ProtoTorch
================

.. toctree::
   :hidden:
   :maxdepth: 3
   :caption: Contents:

   self
   api

ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
research in prototype-based machine learning algorithms.

Indices
=======
* :ref:`genindex`
* :ref:`modindex`
162  examples/gtlvq_mnist.py  Normal file
@@ -0,0 +1,162 @@
"""
ProtoTorch GTLVQ example using MNIST data.
The GTLVQ is placed as a classification model on
top of a CNN, considered as a feature extractor.
Initialization of subspaces and prototypes in
Siamese fashion.
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from prototorch.modules.losses import GLVQLoss
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.models import GTLVQ

# Parameters and options
n_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:1"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')

# Configures reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
    './files/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])),
                                           batch_size=batch_size_train,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
    './files/',
    train=False,
    download=True,
    transform=torchvision.transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])),
                                          batch_size=batch_size_test,
                                          shuffle=True)


# Define the GLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
    def __init__(
        self,
        num_classes,
        subspace_data,
        prototype_data,
        tangent_projection_type="local",
        prototypes_per_class=2,
        bottleneck_dim=128,
    ):
        super(CNNGTLVQ, self).__init__()

        # Feature Extractor - Simple CNN
        self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
                                nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
                                nn.MaxPool2d(2), nn.Dropout(0.25),
                                nn.Flatten(), nn.Linear(9216, bottleneck_dim),
                                nn.Dropout(0.5), nn.LeakyReLU(),
                                nn.LayerNorm(bottleneck_dim))

        # Forward pass of subspace and prototype initialization data
        # through the feature extractor
        subspace_data = self.fe(subspace_data)
        prototype_data[0] = self.fe(prototype_data[0])

        # Initialization of GTLVQ
        self.gtlvq = GTLVQ(num_classes,
                           subspace_data,
                           prototype_data,
                           tangent_projection_type=tangent_projection_type,
                           feature_dim=bottleneck_dim,
                           prototypes_per_class=prototypes_per_class)

    def forward(self, x):
        # Feature extraction
        x = self.fe(x)

        # GTLVQ forward pass
        dis = self.gtlvq(x)
        return dis


# Get init data
subspace_data = torch.cat(
    [next(iter(train_loader))[0],
     next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))

# Build the CNN GTLVQ model
model = CNNGTLVQ(10,
                 subspace_data,
                 prototype_data,
                 tangent_projection_type="local",
                 bottleneck_dim=128).to(device)

# Optimize using the Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam([{
    'params': model.fe.parameters()
}, {
    'params': model.gtlvq.parameters()
}],
                             lr=learning_rate)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)

# Training loop
for epoch in range(n_epochs):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        model.train()
        x_train, y_train = x_train.to(device), y_train.to(device)
        optimizer.zero_grad()

        distances = model(x_train)
        plabels = model.gtlvq.cls.prototype_labels.to(device)

        # Compute loss.
        loss = criterion([distances, plabels], y_train)
        loss.backward()
        optimizer.step()

        # GTLVQ uses projected SGD, which means orthogonalizing the
        # subspaces after every gradient update.
        model.gtlvq.orthogonalize_subspace()

        if batch_idx % log_interval == 0:
            acc = calculate_prototype_accuracy(distances, y_train, plabels)
            print(
                f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
              Train Acc: {acc.item():02.02f}')

    # Test
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            test_distances = model(torch.tensor(x_test))
            test_plabels = model.gtlvq.cls.prototype_labels.to(device)
            i = torch.argmin(test_distances, 1)
            correct += torch.sum(y_test == test_plabels[i])
            total += y_test.size(0)
        print('Accuracy of the network on the test images: %d %%' %
              (torch.true_divide(correct, total) * 100))

# Save the model
PATH = './glvq_mnist_model.pth'
torch.save(model.state_dict(), PATH)
@@ -3,14 +3,7 @@
 # #############################################
 # Core Setup
 # #############################################
-from importlib.metadata import version, PackageNotFoundError
-
-VERSION_FALLBACK = "uninstalled_version"
-try:
-    __version_core__ = version(__name__)
-except PackageNotFoundError:
-    __version_core__ = VERSION_FALLBACK
-    pass
+__version__ = "0.2.0"
 
 from prototorch import datasets, functions, modules
 
@@ -40,14 +33,14 @@ discovered_plugins = discover_plugins()
 locals().update(discovered_plugins)
 
 # Generate combines __version__ and __all__
-__version_plugins__ = "\n".join(
+version_plugins = "\n".join(
     [
         "- " + name + ": v" + plugin.__version__
        for name, plugin in discovered_plugins.items()
     ]
 )
-if __version_plugins__ != "":
-    __version_plugins__ = "\nPlugins: \n" + __version_plugins__
+if version_plugins != "":
+    version_plugins = "\nPlugins: \n" + version_plugins
 
-__version__ = "core: v" + __version_core__ + __version_plugins__
+version = "core: v" + __version__ + version_plugins
 __all__ = __all_core__ + list(discovered_plugins.keys())
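The merged `__init__.py` composes a single human-readable version string from the core version and any discovered plugins. A minimal sketch of that assembly, with a hypothetical plugin dictionary standing in for the result of `discover_plugins()`:

```python
# Sketch only: the plugin dict below is hypothetical and stands in for
# the dictionary returned by discover_plugins().
discovered_plugins = {
    "prototorch_something": type("Plugin", (), {"__version__": "0.1.0"}),
}

__version__ = "0.2.0"
version_plugins = "\n".join([
    "- " + name + ": v" + plugin.__version__
    for name, plugin in discovered_plugins.items()
])
if version_plugins != "":
    version_plugins = "\nPlugins: \n" + version_plugins

version = "core: v" + __version__ + version_plugins
print(version)
# core: v0.2.0
# Plugins:
# - prototorch_something: v0.1.0
```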
@@ -1,6 +1,8 @@
 """ProtoTorch distance functions."""
 
 import torch
+from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
+import numpy as np
 
 
 def squared_euclidean_distance(x, y):
@@ -71,5 +73,155 @@ def lomega_distance(x, y, omegas):
     return distances
 
 
+def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
+    r"""Computes an euclidean distances matrix given two distinct vectors.
+    last dimension must be the vector dimension!
+    compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
+
+    x.shape = (number_of_x_vectors, vector_dim)
+    y.shape = (number_of_y_vectors, vector_dim)
+
+    output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
+    """
+    for tensor in [x, y]:
+        if tensor.ndim != 2:
+            raise ValueError(
+                'The tensor dimension must be two. You provide: tensor.ndim=' +
+                str(tensor.ndim) + '.')
+    if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
+        raise ValueError(
+            'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
+            + str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
+            str(tuple(y.shape)[1]) + '.')
+
+    y = torch.transpose(y)
+
+    diss = torch.sum(x**2, axis=1,
+                     keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
+                         y**2, axis=0, keepdims=True)
+
+    if not squared:
+        if epsilon == 0:
+            diss = torch.sqrt(diss)
+        else:
+            diss = torch.sqrt(torch.max(diss, epsilon))
+
+    return diss
+
+
+def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
+    r"""Tangent distances based on the tensorflow implementation of Sascha Saralajew.
+    For more info about Tangent distances see DOI:10.1109/IJCNN.2016.7727534.
+    The subspaces is always assumed as transposed and must be orthogonal!
+    For local non sparse signals subspaces must be provided!
+    shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
+    shape(protos): proto_number x dim1 x dim2 x ... x dimN
+    shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
+    subspace should be orthogonalized
+    Pytorch implementation of Sascha Saralajew's tensorflow code.
+    Translation by Christoph Raab
+    """
+    signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
+    proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
+    subspace_int_shape = tuple(subspaces.shape)
+
+    # check if the shapes are correct
+    _check_shapes(signal_int_shape, proto_int_shape)
+
+    atom_axes = list(range(3, len(signal_int_shape)))
+    # for sparse signals, we use the memory efficient implementation
+    if signal_int_shape[1] == 1:
+        signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
+
+        if len(atom_axes) > 1:
+            protos = torch.reshape(protos, [proto_shape[0], -1])
+
+        if subspaces.ndim == 2:
+            # clean solution without map if the matrix_scope is global
+            projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
+                subspaces, torch.transpose(subspaces))
+
+            projected_signals = torch.dot(signals, projectors)
+            projected_protos = torch.dot(protos, projectors)
+
+            diss = euclidean_distance_matrix(projected_signals,
+                                             projected_protos,
+                                             squared=squared,
+                                             epsilon=epsilon)
+
+            diss = torch.reshape(
+                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+
+            return torch.permute(diss, [0, 2, 1])
+
+        else:
+            # no solution without map possible --> memory efficient but slow!
+            projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
+                subspaces,
+                subspaces)  # K.batch_dot(subspaces, subspaces, [2, 2])
+
+            projected_protos = (protos @ subspaces
+                                ).T  # K.batch_dot(projectors, protos, [1, 1]))
+
+            def projected_norm(projector):
+                return torch.sum(torch.dot(signals, projector)**2, axis=1)
+
+            diss = torch.transpose(map(projected_norm, projectors)) \
+                - 2 * torch.dot(signals, projected_protos) \
+                + torch.sum(projected_protos**2, axis=0, keepdims=True)
+
+            if not squared:
+                if epsilon == 0:
+                    diss = torch.sqrt(diss)
+                else:
+                    diss = torch.sqrt(torch.max(diss, epsilon))
+
+            diss = torch.reshape(
+                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+
+            return torch.permute(diss, [0, 2, 1])
+
+    else:
+        signals = signals.permute([0, 2, 1] + atom_axes)
+
+        diff = signals - protos
+
+        # global tangent space
+        if subspaces.ndim == 2:
+            # Scope: Projectors
+            projectors = subspaces
+
+            # Scope: Tangentspace Projections
+            diff = torch.reshape(
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+            projected_diff = diff @ projectors
+            projected_diff = torch.reshape(
+                projected_diff,
+                (signal_shape[0], signal_shape[2], signal_shape[1]) +
+                signal_shape[3:])
+
+            diss = torch.norm(projected_diff, 2, dim=-1)
+            return diss.permute([0, 2, 1])
+
+        # local tangent spaces
+        else:
+            # Scope: Calculate Projectors
+            projectors = subspaces
+
+            # Scope: Tangentspace Projections
+            diff = torch.reshape(
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+            diff = diff.permute([1, 0, 2])
+            projected_diff = torch.bmm(diff, projectors)
+            projected_diff = torch.reshape(
+                projected_diff,
+                (signal_shape[1], signal_shape[0], signal_shape[2]) +
+                signal_shape[3:])
+
+            diss = torch.norm(projected_diff, 2, dim=-1)
+            return diss.permute([1, 0, 2]).squeeze(-1)
+
+
 # Aliases
 sed = squared_euclidean_distance
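The new `euclidean_distance_matrix` computes pairwise distances through the dot-product identity ‖x − y‖² = ‖x‖² − 2·x·y + ‖y‖² instead of materializing the full difference tensor. A self-contained sketch of that identity with plain PyTorch ops (the variable names here are illustrative, not part of the library):

```python
import torch

x = torch.randn(5, 3)  # (number_of_x_vectors, vector_dim)
y = torch.randn(7, 3)  # (number_of_y_vectors, vector_dim)

# ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, evaluated for all pairs at once.
squared = (x**2).sum(dim=1, keepdim=True) - 2 * x @ y.T + (y**2).sum(dim=1)

# Cross-check against the naive pairwise subtraction.
naive = ((x.unsqueeze(1) - y.unsqueeze(0))**2).sum(dim=-1)
print(torch.allclose(squared, naive, atol=1e-5))  # True
```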
89  prototorch/functions/helper.py  Normal file
@@ -0,0 +1,89 @@
import torch


def calculate_prototype_accuracy(y_pred, y_true, plabels):
    """Computes the accuracy of a prototype-based model
    via the Winner-Takes-All rule.
    Requirement:
    y_pred.shape == y_true.shape
    unique(y_pred) in plabels
    """
    with torch.no_grad():
        idx = torch.argmin(y_pred, axis=1)
        return torch.true_divide(torch.sum(y_true == plabels[idx]),
                                 len(y_pred)) * 100


def predict_label(y_pred, plabels):
    r"""Predicts labels given a prediction of a prototype-based model."""
    with torch.no_grad():
        return plabels[torch.argmin(y_pred, 1)]


def mixed_shape(inputs):
    if not torch.is_tensor(inputs):
        raise ValueError('Input must be a tensor.')
    else:
        int_shape = list(inputs.shape)
        # sometimes int_shape returns mixed integer types
        int_shape = [int(i) if i is not None else i for i in int_shape]
        tensor_shape = inputs.shape

        for i, s in enumerate(int_shape):
            if s is None:
                int_shape[i] = tensor_shape[i]
        return tuple(int_shape)


def equal_int_shape(shape_1, shape_2):
    if not isinstance(shape_1,
                      (tuple, list)) or not isinstance(shape_2, (tuple, list)):
        raise ValueError('Input shapes must be list or tuple.')
    for shape in [shape_1, shape_2]:
        if not all([isinstance(x, int) or x is None for x in shape]):
            raise ValueError(
                'Input shapes must be list or tuple of int and None values.')

    if len(shape_1) != len(shape_2):
        return False
    else:
        for axis, value in enumerate(shape_1):
            if value is not None and shape_2[axis] not in {value, None}:
                return False
        return True


def _check_shapes(signal_int_shape, proto_int_shape):
    if len(signal_int_shape) < 4:
        raise ValueError(
            "The number of signal dimensions must be >=4. You provide: " +
            str(len(signal_int_shape)))

    if len(proto_int_shape) < 2:
        raise ValueError(
            "The number of proto dimensions must be >=2. You provide: " +
            str(len(proto_int_shape)))

    if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
        raise ValueError(
            "The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
            + str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
            str(proto_int_shape[1:]))

    # not a sparse signal
    if signal_int_shape[1] != 1:
        if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
            raise ValueError(
                "If the signal is not sparse, the number of prototypes must be equal in signals and "
                "protos. You provide: " + str(signal_int_shape[1]) + " != " +
                str(proto_int_shape[0]))

    return True


def _int_and_mixed_shape(tensor):
    shape = mixed_shape(tensor)
    int_shape = tuple([i if isinstance(i, int) else None for i in shape])

    return shape, int_shape
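The Winner-Takes-All helpers above pick the prototype with the smallest distance for each sample. A small, hedged usage sketch with hand-made toy values (assuming the package is importable as in the other new files):

```python
import torch
from prototorch.functions.helper import calculate_prototype_accuracy, predict_label

distances = torch.tensor([[0.2, 1.5, 0.9],
                          [1.1, 0.3, 0.8]])  # (n_samples, n_prototypes)
plabels = torch.tensor([0, 1, 1])            # class label of each prototype
y_true = torch.tensor([0, 1])

print(predict_label(distances, plabels))                         # tensor([0, 1])
print(calculate_prototype_accuracy(distances, y_true, plabels))  # tensor(100.)
```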
37  prototorch/functions/normalization.py  Normal file
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import torch


def orthogonalization(tensors):
    r"""Orthogonalization of a given tensor via polar decomposition."""
    u, _, v = torch.svd(tensors, compute_uv=True)
    u_shape = tuple(list(u.shape))
    v_shape = tuple(list(v.shape))

    # reshape to (num x N x M)
    u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
    v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))

    out = u @ v.permute([0, 2, 1])

    out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))

    return out


def trace_normalization(tensors):
    r"""Trace normalization"""
    epsilon = torch.tensor([1e-10], dtype=torch.float64)
    # Scope trace_normalization
    constant = torch.trace(tensors)

    if epsilon != 0:
        constant = torch.max(constant, epsilon)

    return tensors / constant
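Since `orthogonalization` returns the orthogonal (polar) factor U·Vᵀ of the SVD, a quick sanity check is that QᵀQ stays close to the identity. A hedged sketch with a random batch of square matrices (values purely illustrative):

```python
import torch
from prototorch.functions.normalization import orthogonalization

a = torch.randn(4, 6, 6)   # batch of square matrices
q = orthogonalization(a)   # orthogonal polar factor U @ V^T per matrix

eye = torch.eye(6).expand(4, 6, 6)
print(torch.allclose(q.transpose(1, 2) @ q, eye, atol=1e-4))  # True
```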
190  prototorch/modules/models.py  Normal file
@@ -0,0 +1,190 @@
from torch import nn
import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape


class GTLVQ(nn.Module):
    r"""Generalized Tangent Learning Vector Quantization

    Parameters
    ----------
    num_classes: int
        Number of classes of the given classification problem.

    subspace_data: torch.tensor of shape (n_batch, feature_dim, feature_dim)
        Subspace data for the point approximation, required

    prototype_data: torch.tensor of shape (n_init_data, feature_dim) (optional)
        prototype data for initialization of the prototypes used in GTLVQ.

    subspace_size: int (default=256, optional)
        Subspace dimension of the Projectors. Currently only supported
        with tangent_projection_type=global.

    tangent_projection_type: string
        Specifies the tangent projection type
        options:    local
                    local_proj
                    global
        local: computes the tangent distances without emphasizing projected
        data. Only distances are available
        local_proj: computes tangent distances and returns the projected data
        for further use. Be careful: data is repeated by number of prototypes
        global: Number of subspaces is set to one and every prototype
        uses the same.

    prototypes_per_class: int (default=2, optional)
        Number of prototypes per class

    feature_dim: int (default=256)
        Dimensionality of the feature space specified as integer.
        Prototype dimension.

    Notes
    -----
    The GTLVQ [1] is a prototype-based classification learning model. The
    GTLVQ uses the Tangent-Distances for a local point approximation
    of an assumed data manifold via prototypical representations.

    The GTLVQ requires subspace projectors for transforming the data
    and prototypes into the affine subspace. Every prototype is
    equipped with a specific subspace and represents a point
    approximation of the assumed manifold.

    In practice, prototypes and data are projected on this manifold
    and the pairwise euclidean distance is computed.

    References
    ----------
    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
       in classification based on manifold models and its relation
       to tangent metric learning. In: 2017 International Joint
       Conference on Neural Networks (IJCNN).
       Bd. 2017-May : IEEE, 2017, S. 1756-1765
    """
    def __init__(
        self,
        num_classes,
        subspace_data=None,
        prototype_data=None,
        subspace_size=256,
        tangent_projection_type='local',
        prototypes_per_class=2,
        feature_dim=256,
    ):
        super(GTLVQ, self).__init__()

        self.num_protos = num_classes * prototypes_per_class
        self.subspace_size = feature_dim if subspace_size is None else subspace_size
        self.feature_dim = feature_dim

        if subspace_data is None:
            raise ValueError('Init Data must be specified!')

        self.tpt = tangent_projection_type
        with torch.no_grad():
            if self.tpt == 'local' or self.tpt == 'local_proj':
                self.init_local_subspace(subspace_data)
            elif self.tpt == 'global':
                self.init_gobal_subspace(subspace_data, subspace_size)
            else:
                self.subspaces = None

        # Hypothesis-Margin-Classifier
        self.cls = Prototypes1D(input_dim=feature_dim,
                                prototypes_per_class=prototypes_per_class,
                                nclasses=num_classes,
                                prototype_initializer='stratified_mean',
                                data=prototype_data)

    def forward(self, x):
        # Tangent Projection
        if self.tpt == 'local_proj':
            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                         1).unsqueeze(2)
            dis, proj_x = self.local_tangent_projection(x_conform)

            proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
                                    self.feature_dim)
            return proj_x, dis
        elif self.tpt == "local":
            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                         1).unsqueeze(2)
            dis = tangent_distance(x_conform, self.cls.prototypes,
                                   self.subspaces)
        elif self.tpt == "gloabl":
            dis = self.global_tangent_distances(x)
        else:
            dis = (x @ self.cls.prototypes.T) / (
                torch.norm(x, dim=1, keepdim=True) @ torch.norm(
                    self.cls.prototypes, dim=1, keepdim=True).T)
        return dis

    def init_gobal_subspace(self, data, num_subspaces):
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        self.subspaces = torch.nn.Parameter(
            subspaces).clone().detach().requires_grad_(True)

    def init_local_subspace(self, data):
        _, _, v = torch.svd(data)
        inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = inital_projector.unsqueeze(0).repeat_interleave(
            self.num_protos, 0)
        self.subspaces = torch.nn.Parameter(
            subspaces).clone().detach().requires_grad_(True)

    def global_tangent_distances(self, x):
        # Tangent Projection
        x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
        # Euclidean Distance
        return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_projection(self, signals):
        # Note: subspaces is always assumed as transposed and must be orthogonal!
        # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
        # shape(protos): proto_number x dim1 x dim2 x ... x dimN
        # shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
        # subspace should be orthogonalized
        # Origin Source Code
        # Origin Author:
        protos = self.cls.prototypes
        subspaces = self.subspaces
        signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
        _, proto_int_shape = _int_and_mixed_shape(protos)

        # check if the shapes are correct
        _check_shapes(signal_int_shape, proto_int_shape)

        # Tangent Data Projections
        projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
        data = signals.squeeze(2).permute([1, 0, 2])
        projected_data = torch.bmm(data, subspaces)
        projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
        diff = projected_data - projected_protos
        projected_diff = torch.reshape(
            diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
            signal_shape[3:])
        diss = torch.norm(projected_diff, 2, dim=-1)
        return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)

    def get_parameters(self):
        return {
            "params": self.cls.prototypes,
        }, {
            "params": self.subspaces
        }

    def orthogonalize_subspace(self):
        if self.subspaces is not None:
            with torch.no_grad():
                ortho_subpsaces = orthogonalization(
                    self.subspaces
                ) if self.tpt == 'global' else torch.nn.init.orthogonal_(
                    self.subspaces)
                self.subspaces.copy_(ortho_subpsaces)
@@ -14,55 +14,24 @@ class _Prototypes(torch.nn.Module):
 
     def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the `0` in '
-                          '`prototype_distribution`?')
+            warnings.warn("Are you sure about the `0` in "
+                          "`prototype_distribution`?")
 
     def extra_repr(self):
-        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+        return f"prototypes.shape: {tuple(self.prototypes.shape)}"
 
     def forward(self):
         return self.prototypes, self.prototype_labels
 
 
 class Prototypes1D(_Prototypes):
-    r"""Create a learnable set of one-dimensional prototypes.
+    """Create a learnable set of one-dimensional prototypes.
 
-    TODO Complete this doc-string
-
-    Kwargs:
-        prototypes_per_class: number of prototypes to use per class.
-            Default: ``1``
-        prototype_initializer: prototype initializer.
-            Default: ``'ones'``
-        prototype_distribution: prototype distribution vector.
-            Default: ``None``
-        input_dim: dimension of the incoming data.
-        nclasses: number of classes.
-        data: If set to ``None``, data-dependent initializers will be ignored.
-            Default: ``None``
-
-    Shape:
-        - Input: :math:`(N, H_{in})`
-          where :math:`H_{in} = \text{input_dim}`.
-        - Output: :math:`(N, H_{out})`
-          where :math:`H_{out} = \text{total_prototypes}`.
-
-    Attributes:
-        prototypes: the learnable weights of the module of shape
-            :math:`(\text{total_prototypes}, \text{prototype_dimension})`.
-        prototype_labels: the non-learnable labels of the prototypes.
-
-    Examples:
-
-        >>> p = Prototypes1D(input_dim=20, nclasses=10)
-        >>> input = torch.randn(128, 20)
-        >>> output = m(input)
-        >>> print(output.size())
-        torch.Size([20, 10])
-
+    TODO Complete this doc-string.
     """
     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_initializer='ones',
+                 prototype_initializer="ones",
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
@@ -75,25 +44,25 @@ class Prototypes1D(_Prototypes):
         prototype_distribution = prototype_distribution.tolist()
 
         if data is None:
-            if 'input_dim' not in kwargs:
-                raise NameError('`input_dim` required if '
-                                'no `data` is provided.')
+            if "input_dim" not in kwargs:
+                raise NameError("`input_dim` required if "
+                                "no `data` is provided.")
             if prototype_distribution:
                 kwargs_nclasses = sum(prototype_distribution)
             else:
-                if 'nclasses' not in kwargs:
-                    raise NameError('`prototype_distribution` required if '
-                                    'both `data` and `nclasses` are not '
-                                    'provided.')
-                kwargs_nclasses = kwargs.pop('nclasses')
-            input_dim = kwargs.pop('input_dim')
+                if "nclasses" not in kwargs:
+                    raise NameError("`prototype_distribution` required if "
+                                    "both `data` and `nclasses` are not "
+                                    "provided.")
+                kwargs_nclasses = kwargs.pop("nclasses")
+            input_dim = kwargs.pop("input_dim")
             if prototype_initializer in [
-                    'stratified_mean', 'stratified_random'
+                    "stratified_mean", "stratified_random"
             ]:
                 warnings.warn(
-                    f'`prototype_initializer`: `{prototype_initializer}` '
-                    'requires `data`, but `data` is not provided. '
-                    'Using randomly generated data instead.')
+                    f"`prototype_initializer`: `{prototype_initializer}` "
+                    "requires `data`, but `data` is not provided. "
+                    "Using randomly generated data instead.")
             x_train = torch.rand(kwargs_nclasses, input_dim)
             y_train = torch.arange(kwargs_nclasses)
             if one_hot_labels:
@@ -106,39 +75,39 @@ class Prototypes1D(_Prototypes):
         nclasses = torch.unique(y_train, dim=-1).shape[-1]
 
         if nclasses == 1:
-            warnings.warn('Are you sure about having one class only?')
+            warnings.warn("Are you sure about having one class only?")
 
         if x_train.ndim != 2:
-            raise ValueError('`data[0].ndim != 2`.')
+            raise ValueError("`data[0].ndim != 2`.")
 
         if y_train.ndim == 2:
             if y_train.shape[1] == 1 and one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `True` '
-                                 'but target labels are not one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `True` "
+                                 "but target labels are not one-hot-encoded.")
             if y_train.shape[1] != 1 and not one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `False` '
-                                 'but target labels in `data` '
-                                 'are one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `False` "
+                                 "but target labels in `data` "
+                                 "are one-hot-encoded.")
         if y_train.ndim == 1 and one_hot_labels:
-            raise ValueError('`one_hot_labels` is set to `True` '
-                             'but target labels are not one-hot-encoded.')
+            raise ValueError("`one_hot_labels` is set to `True` "
+                             "but target labels are not one-hot-encoded.")
 
         # Verify input dimension if `input_dim` is provided
-        if 'input_dim' in kwargs:
-            input_dim = kwargs.pop('input_dim')
+        if "input_dim" in kwargs:
+            input_dim = kwargs.pop("input_dim")
             if input_dim != x_train.shape[1]:
-                raise ValueError(f'Provided `input_dim`={input_dim} does '
-                                 'not match data dimension '
-                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+                raise ValueError(f"Provided `input_dim`={input_dim} does "
+                                 "not match data dimension "
+                                 f"`data[0].shape[1]`={x_train.shape[1]}")
 
         # Verify the number of classes if `nclasses` is provided
-        if 'nclasses' in kwargs:
-            kwargs_nclasses = kwargs.pop('nclasses')
+        if "nclasses" in kwargs:
+            kwargs_nclasses = kwargs.pop("nclasses")
             if kwargs_nclasses != nclasses:
-                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
-                                 'not match data labels '
-                                 '`torch.unique(data[1]).shape[0]`'
-                                 f'={nclasses}')
+                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
+                                 "not match data labels "
+                                 "`torch.unique(data[1]).shape[0]`"
+                                 f"={nclasses}")
 
         super().__init__(**kwargs)
 
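The doc-string example removed in this hunk can still serve as a quick usage reminder for `Prototypes1D`; a hedged sketch based on it, assuming no further required arguments (shapes reflect one prototype per class by default):

```python
import torch
from prototorch.modules.prototypes import Prototypes1D

p = Prototypes1D(input_dim=20, nclasses=10)
protos, plabels = p()   # forward() returns the prototypes and their labels
print(protos.shape)     # expected torch.Size([10, 20]): one prototype per class
print(plabels)          # expected tensor([0, 1, ..., 9])
```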
12  setup.py
@@ -12,13 +12,6 @@ ProtoTorch Core Package
 from setuptools import setup
 from setuptools import find_packages
 
-from pkg_resources import safe_name
-
-import ast
-import importlib.util
-
-PKG_DIR = "prototorch"
-
 PROJECT_URL = "https://github.com/si-cim/prototorch"
 DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
 
@@ -49,8 +42,8 @@ TESTS = ["pytest"]
 ALL = DOCS + DATASETS + EXAMPLES + TESTS
 
 setup(
-    name=safe_name(PKG_DIR),
-    use_scm_version=True,
+    name="prototorch",
+    version="0.2.0",
     description="Highly extensible, GPU-supported "
     "Learning Vector Quantization (LVQ) toolbox "
     "built using PyTorch and its nn API.",
@@ -62,7 +55,6 @@ setup(
     download_url=DOWNLOAD_URL,
     license="MIT",
     install_requires=INSTALL_REQUIRES,
-    setup_requires=["setuptools_scm"],
     extras_require={
         "docs": DOCS,
         "datasets": DATASETS,