Merge version 0.2.0 into feature/plugin-architecture.

This commit is contained in:
Alexander Engelsberger 2021-04-19 16:43:15 +02:00
commit c42df6e203
19 changed files with 1001 additions and 109 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.1-rc0
current_version = 0.2.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
@ -19,3 +19,4 @@ values =
[bumpversion:file:./prototorch/__init__.py]
[bumpversion:file:./docs/source/conf.py]

27
.readthedocs.yml Normal file
View File

@ -0,0 +1,27 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
fail_on_warning: true
# Build documentation with MkDocs
# mkdocs:
# configuration: mkdocs.yml
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
install:
- method: pip
path: .
extra_requirements:
- all

View File

@ -45,22 +45,8 @@ pip install -e .[all]
## Documentation
The documentation is available at <https://prototorch.readthedocs.io/en/latest/>
## Usage
### For researchers
ProtoTorch is modular. It is very easy to use the modular pieces provided by
ProtoTorch, like the layers, losses, callbacks and metrics to build your own
prototype-based(instance-based) models. These pieces blend-in seamlessly with
Keras allowing you to mix and match the modules from ProtoFlow with other
modules in `torch.nn`.
### For engineers
ProtoTorch comes prepackaged with many popular Learning Vector Quantization
(LVQ)-like algorithms in a convenient API. If you would simply like to be able
to use those algorithms to train large ML models on a GPU, ProtoTorch lets you
do this without requiring a black-belt in high-performance Tensor computing.
The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
that link not work try <https://prototorch.readthedocs.io/en/latest/>.
## Bibtex

View File

@ -1,5 +1,10 @@
# ProtoTorch Releases
## Release 0.2.0
### Includes
- Fixes in example scripts.
## Release 0.1.1-dev0
### Includes

20
docs/Makefile Normal file
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= python3 -m sphinx
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

4
docs/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
torch==1.6.0
matplotlib==3.1.2
sphinx_rtd_theme==0.5.0
sphinxcontrib-katex==0.6.1

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

28
docs/source/api.rst Normal file
View File

@ -0,0 +1,28 @@
.. ProtoFlow API Reference
ProtoFlow API Reference
======================================
Datasets
--------------------------------------
.. automodule:: prototorch.datasets
:members:
:undoc-members:
Functions
--------------------------------------
.. automodule:: prototorch.functions
:members:
:undoc-members:
Modules
--------------------------------------
.. automodule:: prototorch.modules
:members:
:undoc-members:
Utilities
--------------------------------------
.. automodule:: prototorch.utils
:members:
:undoc-members:

180
docs/source/conf.py Normal file
View File

@ -0,0 +1,180 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath("../../"))
import sphinx_rtd_theme
# -- Project information -----------------------------------------------------
project = "ProtoTorch"
copyright = "2021, Jensun Ravichandran"
author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.2.0"
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
# ones.
extensions = [
"recommonmark",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
"sphinxcontrib.katex",
]
# katex_prerender = True
katex_prerender = False
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]
# The master toctree document.
master_doc = "index"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
# "lovelace", "algol", "algol_nu", "arduino", "rainbo w_dash", "abap",
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
# "stata-dark", "inkpot"]
pygments_style = "monokai"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# Disable docstring inheritance
autodoc_inherit_docstrings = False
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/horizontal-lockup.png"
html_theme_options = {
"logo_only": True,
"display_version": True,
"prev_next_buttons_location": "bottom",
"style_external_links": False,
"style_nav_header_background": "#ffffff",
# Toc options
"collapse_navigation": True,
"sticky_navigation": True,
"navigation_depth": 4,
"includehidden": True,
"titles_only": False,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = [
"https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "protoflowdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ("letterpaper" or "a4paper").
#
# "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt").
#
# "pointsize": "10pt",
# Additional stuff for the LaTeX preamble.
#
# "preamble": "",
# Latex figure (float) alignment
#
# "figure_align": "htbp",
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "prototorch.tex", "ProtoTorch Documentation",
"Jensun Ravichandran", "manual"),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
"Prototype-based machine learning in PyTorch.",
"Miscellaneous"),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
}
# -- Options for Epub output ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
epub_cover = ()
version = release

22
docs/source/index.rst Normal file
View File

@ -0,0 +1,22 @@
.. ProtoTorch documentation master file
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
About ProtoTorch
================
.. toctree::
:hidden:
:maxdepth: 3
:caption: Contents:
self
api
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge
research in prototype-based machine learning algorithms.
Indices
=======
* :ref:`genindex`
* :ref:`modindex`

162
examples/gtlvq_mnist.py Normal file
View File

@ -0,0 +1,162 @@
"""
ProtoTorch GTLVQ example using MNIST data.
The GTLVQ is placed as an classification model on
top of a CNN, considered as featurer extractor.
Initialization of subpsace and prototypes in
Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from prototorch.modules.losses import GLVQLoss
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.models import GTLVQ
# Parameters and options
n_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:1"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
# Configures reproducability
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
'./files/',
train=True,
download=True,
transform=torchvision.transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])),
batch_size=batch_size_train,
shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
'./files/',
train=False,
download=True,
transform=torchvision.transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])),
batch_size=batch_size_test,
shuffle=True)
# Define the GLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
def __init__(
self,
num_classes,
subspace_data,
prototype_data,
tangent_projection_type="local",
prototypes_per_class=2,
bottleneck_dim=128,
):
super(CNNGTLVQ, self).__init__()
#Feature Extractor - Simple CNN
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
nn.MaxPool2d(2), nn.Dropout(0.25),
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
nn.Dropout(0.5), nn.LeakyReLU(),
nn.LayerNorm(bottleneck_dim))
# Forward pass of subspace and prototype initialization data through feature extractor
subspace_data = self.fe(subspace_data)
prototype_data[0] = self.fe(prototype_data[0])
# Initialization of GTLVQ
self.gtlvq = GTLVQ(num_classes,
subspace_data,
prototype_data,
tangent_projection_type=tangent_projection_type,
feature_dim=bottleneck_dim,
prototypes_per_class=prototypes_per_class)
def forward(self, x):
# Feature Extraction
x = self.fe(x)
# GTLVQ Forward pass
dis = self.gtlvq(x)
return dis
# Get init data
subspace_data = torch.cat(
[next(iter(train_loader))[0],
next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))
# Build the CNN GTLVQ model
model = CNNGTLVQ(10,
subspace_data,
prototype_data,
tangent_projection_type="local",
bottleneck_dim=128).to(device)
# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.Adam([{
'params': model.fe.parameters()
}, {
'params': model.gtlvq.parameters()
}],
lr=learning_rate)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
# Training loop
for epoch in range(n_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train()
x_train, y_train = x_train.to(device), y_train.to(device)
optimizer.zero_grad()
distances = model(x_train)
plabels = model.gtlvq.cls.prototype_labels.to(device)
# Compute loss.
loss = criterion([distances, plabels], y_train)
loss.backward()
optimizer.step()
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
model.gtlvq.orthogonalize_subspace()
if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels)
print(
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}')
# Test
with torch.no_grad():
model.eval()
correct = 0
total = 0
for x_test, y_test in test_loader:
x_test, y_test = x_test.to(device), y_test.to(device)
test_distances = model(torch.tensor(x_test))
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
i = torch.argmin(test_distances, 1)
correct += torch.sum(y_test == test_plabels[i])
total += y_test.size(0)
print('Accuracy of the network on the test images: %d %%' %
(torch.true_divide(correct, total) * 100))
# Save the model
PATH = './glvq_mnist_model.pth'
torch.save(model.state_dict(), PATH)

View File

@ -3,14 +3,7 @@
# #############################################
# Core Setup
# #############################################
from importlib.metadata import version, PackageNotFoundError
VERSION_FALLBACK = "uninstalled_version"
try:
__version_core__ = version(__name__)
except PackageNotFoundError:
__version_core__ = VERSION_FALLBACK
pass
__version__ = "0.2.0"
from prototorch import datasets, functions, modules
@ -40,14 +33,14 @@ discovered_plugins = discover_plugins()
locals().update(discovered_plugins)
# Generate combines __version__ and __all__
__version_plugins__ = "\n".join(
version_plugins = "\n".join(
[
"- " + name + ": v" + plugin.__version__
for name, plugin in discovered_plugins.items()
]
)
if __version_plugins__ != "":
__version_plugins__ = "\nPlugins: \n" + __version_plugins__
if version_plugins != "":
version_plugins = "\nPlugins: \n" + version_plugins
__version__ = "core: v" + __version_core__ + __version_plugins__
version = "core: v" + __version__ + version_plugins
__all__ = __all_core__ + list(discovered_plugins.keys())

View File

@ -1,6 +1,8 @@
"""ProtoTorch distance functions."""
import torch
from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
import numpy as np
def squared_euclidean_distance(x, y):
@ -71,5 +73,155 @@ def lomega_distance(x, y, omegas):
return distances
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
r""" Computes an euclidean distanes matrix given two distinct vectors.
last dimension must be the vector dimension!
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
x.shape = (number_of_x_vectors, vector_dim)
y.shape = (number_of_y_vectors, vector_dim)
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
"""
for tensor in [x, y]:
if tensor.ndim != 2:
raise ValueError(
'The tensor dimension must be two. You provide: tensor.ndim=' +
str(tensor.ndim) + '.')
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
raise ValueError(
'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
+ str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
str(tuple(y.shape)[1]) + '.')
y = torch.transpose(y)
diss = torch.sum(x**2, axis=1,
keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
y**2, axis=0, keepdims=True)
if not squared:
if epsilon == 0:
diss = torch.sqrt(diss)
else:
diss = torch.sqrt(torch.max(diss, epsilon))
return diss
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
The subspaces is always assumed as transposed and must be orthogonal!
For local non sparse signals subspaces must be provided!
shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
shape(protos): proto_number x dim1 x dim2 x ... x dimN
shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
subspace should be orthogonalized
Pytorch implementation of Sascha Saralajew's tensorflow code.
Translation by Christoph Raab
"""
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
subspace_int_shape = tuple(subspaces.shape)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# for sparse signals, we use the memory efficient implementation
if signal_int_shape[1] == 1:
signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
if len(atom_axes) > 1:
protos = torch.reshape(protos, [proto_shape[0], -1])
if subspaces.ndim == 2:
# clean solution without map if the matrix_scope is global
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
subspaces, torch.transpose(subspaces))
projected_signals = torch.dot(signals, projectors)
projected_protos = torch.dot(protos, projectors)
diss = euclidean_distance_matrix(projected_signals,
projected_protos,
squared=squared,
epsilon=epsilon)
diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
return torch.permute(diss, [0, 2, 1])
else:
# no solution without map possible --> memory efficient but slow!
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
subspaces,
subspaces) #K.batch_dot(subspaces, subspaces, [2, 2])
projected_protos = (protos @ subspaces
).T #K.batch_dot(projectors, protos, [1, 1]))
def projected_norm(projector):
return torch.sum(torch.dot(signals, projector)**2, axis=1)
diss = torch.transpose(map(projected_norm, projectors)) \
- 2 * torch.dot(signals, projected_protos) \
+ torch.sum(projected_protos**2, axis=0, keepdims=True)
if not squared:
if epsilon == 0:
diss = torch.sqrt(diss)
else:
diss = torch.sqrt(torch.max(diss, epsilon))
diss = torch.reshape(
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
return torch.permute(diss, [0, 2, 1])
else:
signals = signals.permute([0, 2, 1] + atom_axes)
diff = signals - protos
# global tangent space
if subspaces.ndim == 2:
#Scope Projectors
projectors = subspaces #
#Scope: Tangentspace Projections
diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
projected_diff = diff @ projectors
projected_diff = torch.reshape(
projected_diff,
(signal_shape[0], signal_shape[2], signal_shape[1]) +
signal_shape[3:])
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([0, 2, 1])
# local tangent spaces
else:
# Scope: Calculate Projectors
projectors = subspaces
# Scope: Tangentspace Projections
diff = torch.reshape(
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
diff = diff.permute([1, 0, 2])
projected_diff = torch.bmm(diff, projectors)
projected_diff = torch.reshape(
projected_diff,
(signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1)
# Aliases
sed = squared_euclidean_distance

View File

@ -0,0 +1,89 @@
import torch
def calculate_prototype_accuracy(y_pred, y_true, plabels):
"""Computes the accuracy of a prototype based model.
via Winner-Takes-All rule.
Requirement:
y_pred.shape == y_true.shape
unique(y_pred) in plabels
"""
with torch.no_grad():
idx = torch.argmin(y_pred, axis=1)
return torch.true_divide(torch.sum(y_true == plabels[idx]),
len(y_pred)) * 100
def predict_label(y_pred, plabels):
r""" Predicts labels given a prediction of a prototype based model.
"""
with torch.no_grad():
return plabels[torch.argmin(y_pred, 1)]
def mixed_shape(inputs):
if not torch.is_tensor(inputs):
raise ValueError('Input must be a tensor.')
else:
int_shape = list(inputs.shape)
# sometimes int_shape returns mixed integer types
int_shape = [int(i) if i is not None else i for i in int_shape]
tensor_shape = inputs.shape
for i, s in enumerate(int_shape):
if s is None:
int_shape[i] = tensor_shape[i]
return tuple(int_shape)
def equal_int_shape(shape_1, shape_2):
if not isinstance(shape_1,
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
raise ValueError('Input shapes must list or tuple.')
for shape in [shape_1, shape_2]:
if not all([isinstance(x, int) or x is None for x in shape]):
raise ValueError(
'Input shapes must be list or tuple of int and None values.')
if len(shape_1) != len(shape_2):
return False
else:
for axis, value in enumerate(shape_1):
if value is not None and shape_2[axis] not in {value, None}:
return False
return True
def _check_shapes(signal_int_shape, proto_int_shape):
if len(signal_int_shape) < 4:
raise ValueError(
"The number of signal dimensions must be >=4. You provide: " +
str(len(signal_int_shape)))
if len(proto_int_shape) < 2:
raise ValueError(
"The number of proto dimensions must be >=2. You provide: " +
str(len(proto_int_shape)))
if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
raise ValueError(
"The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
+ str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
str(proto_int_shape[1:]))
# not a sparse signal
if signal_int_shape[1] != 1:
if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
raise ValueError(
"If the signal is not sparse, the number of prototypes must be equal in signals and "
"protos. You provide: " + str(signal_int_shape[1]) + " != " +
str(proto_int_shape[0]))
return True
def _int_and_mixed_shape(tensor):
shape = mixed_shape(tensor)
int_shape = tuple([i if isinstance(i, int) else None for i in shape])
return shape, int_shape

View File

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import torch
def orthogonalization(tensors):
r""" Orthogonalization of a given tensor via polar decomposition.
"""
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def trace_normalization(tensors):
r""" Trace normalization
"""
epsilon = torch.tensor([1e-10], dtype=torch.float64)
# Scope trace_normalization
constant = torch.trace(tensors)
if epsilon != 0:
constant = torch.max(constant, epsilon)
return tensors / constant

View File

@ -0,0 +1,190 @@
from torch import nn
import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
Parameters
----------
num_classes: int
Number of classes of the given classification problem.
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
Subspace data for the point approximation, required
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ.
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
local_proj
global
local: computes the tangent distances without emphasizing projected
data. Only distances are available
local_proj: computs tangent distances and returns the projected data
for further use. Be careful: data is repeated by number of prototypes
global: Number of subspaces is set to one and every prototypes
uses the same.
prototypes_per_class: int (default=2,optional)
Number of prototypes per class
feature_dim: int (default=256)
Dimensionality of the feature space specified as integer.
Prototype dimension.
Notes
-----
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ requires subspace projectors for transforming the data
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
approximation of the assumed manifold.
In practice prototypes and data are projected on this manifold
and pairwise euclidean distance computes.
References
----------
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
in classification based on manifolc. models and its relation
to tangent metric learning. In: 2017 International Joint
Conference on Neural Networks (IJCNN).
Bd. 2017-May : IEEE, 2017, S. 17561765
"""
def __init__(
self,
num_classes,
subspace_data=None,
prototype_data=None,
subspace_size=256,
tangent_projection_type='local',
prototypes_per_class=2,
feature_dim=256,
):
super(GTLVQ, self).__init__()
self.num_protos = num_classes * prototypes_per_class
self.subspace_size = feature_dim if subspace_size is None else subspace_size
self.feature_dim = feature_dim
if subspace_data is None:
raise ValueError('Init Data must be specified!')
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer='stratified_mean',
data=prototype_data)
def forward(self, x):
# Tangent Projection
if self.tpt == 'local_proj':
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
elif self.tpt == "gloabl":
dis = self.global_tangent_distances(x)
else:
dis = (x @ self.cls.prototypes.T) / (
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
self.cls.prototypes, dim=1, keepdim=True).T)
return dis
def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x):
# Tangent Projection
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
# Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_projection(self,
signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
# Origin Source Code
# Origin Author:
protos = self.cls.prototypes
subspaces = self.subspaces
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
projected_data = torch.bmm(data, subspaces)
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
diff = projected_data - projected_protos
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self):
return {
"params": self.cls.prototypes,
}, {
"params": self.subspaces
}
def orthogonalize_subspace(self):
if self.subspaces is not None:
with torch.no_grad():
ortho_subpsaces = orthogonalization(
self.subspaces
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
self.subspaces)
self.subspaces.copy_(ortho_subpsaces)

View File

@ -14,55 +14,24 @@ class _Prototypes(torch.nn.Module):
def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution:
warnings.warn('Are you sure about the `0` in '
'`prototype_distribution`?')
warnings.warn("Are you sure about the `0` in "
"`prototype_distribution`?")
def extra_repr(self):
return f'prototypes.shape: {tuple(self.prototypes.shape)}'
return f"prototypes.shape: {tuple(self.prototypes.shape)}"
def forward(self):
return self.prototypes, self.prototype_labels
class Prototypes1D(_Prototypes):
r"""Create a learnable set of one-dimensional prototypes.
"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string
Kwargs:
prototypes_per_class: number of prototypes to use per class.
Default: ``1``
prototype_initializer: prototype initializer.
Default: ``'ones'``
prototype_distribution: prototype distribution vector.
Default: ``None``
input_dim: dimension of the incoming data.
nclasses: number of classes.
data: If set to ``None``, data-dependent initializers will be ignored.
Default: ``None``
Shape:
- Input: :math:`(N, H_{in})`
where :math:`H_{in} = \text{input_dim}`.
- Output: :math:`(N, H_{out})`
where :math:`H_{out} = \text{total_prototypes}`.
Attributes:
prototypes: the learnable weights of the module of shape
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
prototype_labels: the non-learnable labels of the prototypes.
Examples:
>>> p = Prototypes1D(input_dim=20, nclasses=10)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([20, 10])
TODO Complete this doc-string.
"""
def __init__(self,
prototypes_per_class=1,
prototype_initializer='ones',
prototype_initializer="ones",
prototype_distribution=None,
data=None,
dtype=torch.float32,
@ -75,25 +44,25 @@ class Prototypes1D(_Prototypes):
prototype_distribution = prototype_distribution.tolist()
if data is None:
if 'input_dim' not in kwargs:
raise NameError('`input_dim` required if '
'no `data` is provided.')
if "input_dim" not in kwargs:
raise NameError("`input_dim` required if "
"no `data` is provided.")
if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution)
else:
if 'nclasses' not in kwargs:
raise NameError('`prototype_distribution` required if '
'both `data` and `nclasses` are not '
'provided.')
kwargs_nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim')
if "nclasses" not in kwargs:
raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not "
"provided.")
kwargs_nclasses = kwargs.pop("nclasses")
input_dim = kwargs.pop("input_dim")
if prototype_initializer in [
'stratified_mean', 'stratified_random'
"stratified_mean", "stratified_random"
]:
warnings.warn(
f'`prototype_initializer`: `{prototype_initializer}` '
'requires `data`, but `data` is not provided. '
'Using randomly generated data instead.')
f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. "
"Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
if one_hot_labels:
@ -106,39 +75,39 @@ class Prototypes1D(_Prototypes):
nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
warnings.warn('Are you sure about having one class only?')
warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2:
raise ValueError('`data[0].ndim != 2`.')
raise ValueError("`data[0].ndim != 2`.")
if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError('`one_hot_labels` is set to `False` '
'but target labels in `data` '
'are one-hot-encoded.')
raise ValueError("`one_hot_labels` is set to `False` "
"but target labels in `data` "
"are one-hot-encoded.")
if y_train.ndim == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
# Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs:
input_dim = kwargs.pop('input_dim')
if "input_dim" in kwargs:
input_dim = kwargs.pop("input_dim")
if input_dim != x_train.shape[1]:
raise ValueError(f'Provided `input_dim`={input_dim} does '
'not match data dimension '
f'`data[0].shape[1]`={x_train.shape[1]}')
raise ValueError(f"Provided `input_dim`={input_dim} does "
"not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided
if 'nclasses' in kwargs:
kwargs_nclasses = kwargs.pop('nclasses')
if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses:
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
'not match data labels '
'`torch.unique(data[1]).shape[0]`'
f'={nclasses}')
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={nclasses}")
super().__init__(**kwargs)

View File

@ -12,13 +12,6 @@ ProtoTorch Core Package
from setuptools import setup
from setuptools import find_packages
from pkg_resources import safe_name
import ast
import importlib.util
PKG_DIR = "prototorch"
PROJECT_URL = "https://github.com/si-cim/prototorch"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
@ -49,8 +42,8 @@ TESTS = ["pytest"]
ALL = DOCS + DATASETS + EXAMPLES + TESTS
setup(
name=safe_name(PKG_DIR),
use_scm_version=True,
name="prototorch",
version="0.2.0",
description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.",
@ -62,7 +55,6 @@ setup(
download_url=DOWNLOAD_URL,
license="MIT",
install_requires=INSTALL_REQUIRES,
setup_requires=["setuptools_scm"],
extras_require={
"docs": DOCS,
"datasets": DATASETS,