Update Documentation

Clean up project
This commit is contained in:
Alexander Engelsberger 2021-05-21 15:42:45 +02:00
parent a5e086ce0d
commit 7b4f7d84e0
11 changed files with 146 additions and 126 deletions

2
.gitignore vendored
View File

@ -133,3 +133,5 @@ datasets/
# PyTorch-Lightning # PyTorch-Lightning
lightning_logs/ lightning_logs/
.vscode/

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

View File

@ -104,7 +104,7 @@ autodoc_inherit_docstrings = False
# https://sphinx-themes.org/ # https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme" html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/horizontal-lockup.png" html_logo = "_static/img/logo.png"
html_theme_options = { html_theme_options = {
"logo_only": True, "logo_only": True,
@ -168,8 +168,8 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], man_pages = [(master_doc, "ProtoTorch Models",
1)] "ProtoTorch Models Plugin Documentation", [author], 1)]
# -- Options for Texinfo output ------------------------------------------- # -- Options for Texinfo output -------------------------------------------
@ -179,19 +179,22 @@ man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
texinfo_documents = [ texinfo_documents = [
( (
master_doc, master_doc,
"prototorch", "prototorch models",
"ProtoTorch Documentation", "ProtoTorch Models Plugin Documentation",
author, author,
"prototorch", "prototorch models",
"Prototype-based machine learning in PyTorch.", "Prototype-based machine learning Models in ProtoTorch.",
"Miscellaneous", "Miscellaneous",
), ),
] ]
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = { intersphinx_mapping = {
"python": ("https://docs.python.org/", None), "python": ("https://docs.python.org/3/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None), "numpy": ("https://numpy.org/doc/stable/", None),
"torch": ('https://pytorch.org/docs/stable/', None),
"pytorch_lightning":
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
} }
# -- Options for Epub output ---------------------------------------------- # -- Options for Epub output ----------------------------------------------

9
docs/source/custom.rst Normal file
View File

@ -0,0 +1,9 @@
.. Customize the Models
Abstract Models
========================================
.. autoclass:: prototorch.models.abstract.AbstractPrototypeModel
:members:
.. autoclass:: prototorch.models.abstract.PrototypeImageModel
:members:

View File

@ -1,25 +1,40 @@
.. ProtoTorch Models documentation master file .. ProtoTorch Models documentation master file
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
About ProtoTorch Models ProtoTorch Models Plugins
======================== ========================================
.. toctree::
:hidden:
:maxdepth: 3
self
tutorial.ipynb
.. toctree:: .. toctree::
:hidden: :hidden:
:maxdepth: 3 :maxdepth: 3
:caption: Contents: :caption: Library
self library
models
tutorial.ipynb
.. toctree::
:hidden:
:maxdepth: 3
:caption: Customize
custom
About
-----------------------------------------
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin `Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
for `Prototorch <https://github.com/si-cim/prototorch>`_. It implements common for `Prototorch <https://github.com/si-cim/prototorch>`_. It implements common
prototype-based Machine Learning algorithms using `PyTorch-Lightning prototype-based Machine Learning algorithms using `PyTorch-Lightning
<https://www.pytorchlightning.ai/>`_. <https://www.pytorchlightning.ai/>`_.
Indices Library
======= -----------------------------------------
* :ref:`genindex` Prototorch Models delivers many application ready models.
* :ref:`modindex` These models have been published in the past and have been adapted to the Prototorch library.
Customizable
-----------------------------------------
Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch.

View File

@ -1,27 +1,35 @@
.. Available Models .. Available Models
Available Models Models
======================================== ========================================
Unsupervised Methods Unsupervised Methods
----------------------------------------- -----------------------------------------
.. autoclass:: prototorch.models.knn.KNN .. autoclass:: prototorch.models.unsupervised.KNN
:members: :members:
.. autoclass:: prototorch.models.neural_gas.NeuralGas .. autoclass:: prototorch.models.unsupervised.NeuralGas
:members: :members:
Classical Learning Vector Quantization Classical Learning Vector Quantization
----------------------------------------- -----------------------------------------
Original LVQ models. Implementations use GLVQ structure as shown in [Sato&Yamada]. Original LVQ models by Kohonen.
These heuristic algorithms do not use gradient descent.
.. autoclass:: prototorch.models.glvq.LVQ1 .. autoclass:: prototorch.models.glvq.LVQ1
:members: :members:
.. autoclass:: prototorch.models.glvq.LVQ21 .. autoclass:: prototorch.models.glvq.LVQ21
:members: :members:
It is also possible to use the GLVQ structure as shown in [Sato&Yamada].
This allows the use of gradient descent methods.
.. autoclass:: prototorch.models.glvq.GLVQ1
:members:
.. autoclass:: prototorch.models.glvq.GLVQ21
:members:
Generalized Learning Vector Quantization Generalized Learning Vector Quantization
----------------------------------------- -----------------------------------------
@ -43,10 +51,17 @@ Generalized Learning Vector Quantization
.. autoclass:: prototorch.models.glvq.LVQMLN .. autoclass:: prototorch.models.glvq.LVQMLN
:members: :members:
CBC Classification by Component
----------------------------------------- -----------------------------------------
.. autoclass:: prototorch.models.cbc.CBC .. autoclass:: prototorch.models.cbc.CBC
:members: :members:
.. autoclass:: prototorch.models.cbc.ImageCBC .. autoclass:: prototorch.models.cbc.ImageCBC
:members: :members:
Visualization
========================================
.. automodule:: prototorch.models.vis
:members:
:undoc-members:

View File

@ -3,8 +3,7 @@ from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC, ImageCBC from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ) ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
from .knn import KNN from .unsupervised import KNN, NeuralGas
from .neural_gas import NeuralGas
from .vis import * from .vis import *
__version__ = "0.1.7" __version__ = "0.1.7"

View File

@ -1,14 +0,0 @@
"""Callbacks for Pytorch Lighning Modules"""
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")

View File

@ -1,3 +1,4 @@
"""Models based on the GLVQ Framework"""
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
@ -6,15 +7,8 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance, from prototorch.functions.distances import (euclidean_distance, omega_distance,
sed) sed)
from prototorch.functions.helper import get_flat from prototorch.functions.helper import get_flat
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss, from prototorch.functions.losses import (_get_dp_dm, glvq_loss, lvq1_loss,
lvq1_loss, lvq21_loss) lvq21_loss)
from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
from .abstract import AbstractPrototypeModel, PrototypeImageModel from .abstract import AbstractPrototypeModel, PrototypeImageModel
@ -192,11 +186,14 @@ class GRLVQ(SiameseGLVQ):
self.relevances = torch.nn.parameter.Parameter( self.relevances = torch.nn.parameter.Parameter(
torch.ones(self.hparams.input_dim)) torch.ones(self.hparams.input_dim))
# Overwrite backbone
self.backbone = self._backbone
@property @property
def relevance_profile(self): def relevance_profile(self):
return self.relevances.detach().cpu() return self.relevances.detach().cpu()
def backbone(self, x): def _backbone(self, x):
"""Namespace hook for the visualization callbacks to work.""" """Namespace hook for the visualization callbacks to work."""
return x @ torch.diag(self.relevances) return x @ torch.diag(self.relevances)
@ -262,6 +259,7 @@ class LVQMLN(SiameseGLVQ):
class NonGradientGLVQ(GLVQ): class NonGradientGLVQ(GLVQ):
"""Abstract Model for Models that do not use gradients in their update phase."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.automatic_optimization = False self.automatic_optimization = False
@ -271,6 +269,7 @@ class NonGradientGLVQ(GLVQ):
class LVQ1(NonGradientGLVQ): class LVQ1(NonGradientGLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos = self.proto_layer.components
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
@ -299,6 +298,7 @@ class LVQ1(NonGradientGLVQ):
class LVQ21(NonGradientGLVQ): class LVQ21(NonGradientGLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos = self.proto_layer.components
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
@ -311,8 +311,7 @@ class LVQ21(NonGradientGLVQ):
xi = xi.view(1, -1) xi = xi.view(1, -1)
yi = yi.view(1, ) yi = yi.view(1, )
d = self(xi) d = self(xi)
preds = wtac(d, plabels) (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
(dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
shiftp = xi - protos[wp] shiftp = xi - protos[wp]
shiftn = protos[wn] - xi shiftn = protos[wn] - xi
updated_protos = protos + 0.0 updated_protos = protos + 0.0
@ -328,11 +327,11 @@ class LVQ21(NonGradientGLVQ):
class MedianLVQ(NonGradientGLVQ): class MedianLVQ(NonGradientGLVQ):
... """Median LVQ"""
class GLVQ1(GLVQ): class GLVQ1(GLVQ):
"""Learning Vector Quantization 1.""" """Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = lvq1_loss self.loss = lvq1_loss
@ -340,7 +339,7 @@ class GLVQ1(GLVQ):
class GLVQ21(GLVQ): class GLVQ21(GLVQ):
"""Learning Vector Quantization 2.1.""" """Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = lvq21_loss self.loss = lvq21_loss
@ -354,7 +353,6 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
after updates. after updates.
""" """
pass
class ImageGMLVQ(PrototypeImageModel, GMLVQ): class ImageGMLVQ(PrototypeImageModel, GMLVQ):
@ -364,4 +362,3 @@ class ImageGMLVQ(PrototypeImageModel, GMLVQ):
after updates. after updates.
""" """
pass

View File

@ -1,62 +0,0 @@
"""The popular K-Nearest-Neighbors classification algorithm."""
import warnings
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from .abstract import AbstractPrototypeModel
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None

View File

@ -1,7 +1,13 @@
"""Unsupervised prototype learning algorithms."""
import warnings
import torch import torch
from prototorch.components import Components import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.components import initializers as cinit from prototorch.components import initializers as cinit
from prototorch.components.initializers import ZerosInitializer from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
@ -36,6 +42,56 @@ class ConnectionTopology(torch.nn.Module):
return f"agelimit: {self.agelimit}" return f"agelimit: {self.agelimit}"
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None
class NeuralGas(AbstractPrototypeModel): class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()