Compare commits
83 Commits
feature/je
...
v1.0.0a8
Author | SHA1 | Date | |
---|---|---|---|
|
9bb2e20dce | ||
|
6748951b63 | ||
|
c547af728b | ||
|
482044ec87 | ||
|
45f01f39d4 | ||
|
9ab864fbdf | ||
|
365e0fb931 | ||
|
ba50dfba50 | ||
|
16ca409f07 | ||
|
c3cad19853 | ||
|
ec294bdd37 | ||
|
e0abb1f3de | ||
|
918e599c6a | ||
|
ec61881ca8 | ||
|
5a89f24c10 | ||
|
bcf9c6bdb1 | ||
|
736565b768 | ||
|
94730f492b | ||
|
46ec7b07d7 | ||
|
07dab5a5ca | ||
|
ed83138e1f | ||
|
1be7d7ec09 | ||
|
60d2a1d2c9 | ||
|
be7d7f43bd | ||
|
fe729781fc | ||
|
a7df7be1c8 | ||
|
696719600b | ||
|
48e7c029fa | ||
|
5de3a480c7 | ||
|
626f51ce80 | ||
|
6d7d93c8e8 | ||
|
93b1d0bd46 | ||
|
b7992c01db | ||
|
fcd944d3ff | ||
|
054720dd7b | ||
|
23d1a71b31 | ||
|
e922aae432 | ||
|
3e50d0d817 | ||
|
dc4f31d700 | ||
|
02954044d7 | ||
|
8f08ba66ea | ||
|
e0b92e9ac2 | ||
|
d16a0de202 | ||
|
76fea3f881 | ||
|
c00513ae0d | ||
|
bccef8bef0 | ||
|
29ee326b85 | ||
|
055568dc86 | ||
|
3a7328e290 | ||
|
d6629c8792 | ||
|
ef65bd3789 | ||
|
d096eba2c9 | ||
|
dd34c57e2e | ||
|
5911f4dd90 | ||
|
dbfe315f4f | ||
|
9c90c902dc | ||
|
7d3f59e54b | ||
|
9da47b1dba | ||
|
41f0e77fc9 | ||
|
fab786a07e | ||
|
40bd7ed380 | ||
|
4941c2b89d | ||
|
ce14dec7e9 | ||
|
b31c8cc707 | ||
|
e21e6c7e02 | ||
|
dd696ea1e0 | ||
|
15e7232747 | ||
|
197b728c63 | ||
|
98892afee0 | ||
|
d5855dbe97 | ||
|
75a39f5b03 | ||
|
1a0e697b27 | ||
|
1a17193b35 | ||
|
aaa3c51e0a | ||
|
62c5974a85 | ||
|
1d26226a2f | ||
|
4232d0ed2a | ||
|
a9edf06507 | ||
|
d3bb430104 | ||
|
6ffd27d12a | ||
|
859e2cae69 | ||
|
d7ea89d47e | ||
|
fa928afe2c |
@@ -1,9 +1,11 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.0
|
||||
current_version = 1.0.0a8
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
serialize = {major}.{minor}.{patch}
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
|
||||
serialize =
|
||||
{major}.{minor}.{patch}-{release}
|
||||
{major}.{minor}.{patch}
|
||||
message = build: bump version {current_version} → {new_version}
|
||||
|
||||
[bumpversion:file:setup.py]
|
||||
|
15
.codacy.yml
15
.codacy.yml
@@ -1,15 +0,0 @@
|
||||
# To validate the contents of your configuration file
|
||||
# run the following command in the folder where the configuration file is located:
|
||||
# codacy-analysis-cli validate-configuration --directory `pwd`
|
||||
# To analyse, run:
|
||||
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
||||
---
|
||||
engines:
|
||||
pylintpython3:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
remark-lint:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
exclude_paths:
|
||||
- 'tests/**'
|
@@ -1,2 +0,0 @@
|
||||
comment:
|
||||
require_changes: yes
|
25
.github/workflows/examples.yml
vendored
Normal file
25
.github/workflows/examples.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: examples
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'examples/**.py'
|
||||
jobs:
|
||||
cpu:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- name: Run examples
|
||||
run: |
|
||||
./tests/test_examples.sh examples/
|
76
.github/workflows/pythonapp.yml
vendored
Normal file
76
.github/workflows/pythonapp.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
style:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
compatibility:
|
||||
needs: style
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
exclude:
|
||||
- os: windows-latest
|
||||
python-version: "3.7"
|
||||
- os: windows-latest
|
||||
python-version: "3.8"
|
||||
- os: windows-latest
|
||||
python-version: "3.9"
|
||||
- os: windows-latest
|
||||
python-version: "3.11"
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest
|
||||
publish_pypi:
|
||||
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
|
||||
needs: compatibility
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
pip install wheel
|
||||
- name: Build package
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish a Python distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
@@ -3,9 +3,10 @@
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
rev: v4.3.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
@@ -13,24 +14,24 @@ repos:
|
||||
- id: check-case-conflict
|
||||
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
rev: v1.7.7
|
||||
hooks:
|
||||
- id: autoflake
|
||||
|
||||
- repo: http://github.com/PyCQA/isort
|
||||
rev: 5.8.0
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.902
|
||||
rev: v0.982
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: prototorch
|
||||
additional_dependencies: [types-pkg_resources]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.31.0
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
|
||||
@@ -42,7 +43,7 @@ repos:
|
||||
- id: python-check-blanket-noqa
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.19.4
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
@@ -51,3 +52,8 @@ repos:
|
||||
hooks:
|
||||
- id: gitlint
|
||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||
|
||||
- repo: https://github.com/dosisod/refurb
|
||||
rev: v1.4.0
|
||||
hooks:
|
||||
- id: refurb
|
||||
|
44
.travis.yml
44
.travis.yml
@@ -1,44 +0,0 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python:
|
||||
- 3.9
|
||||
- 3.8
|
||||
- 3.7
|
||||
- 3.6
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
- "./tests/artifacts"
|
||||
- "$HOME/datasets"
|
||||
install:
|
||||
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
|
||||
- pip install .[all] --progress-bar off
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
- ./tests/test_examples.sh examples/
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
jobs:
|
||||
include:
|
||||
- stage: build
|
||||
python: 3.9
|
||||
script: echo "Starting Pypi build"
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
distributions: "sdist bdist_wheel"
|
||||
password:
|
||||
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
@@ -1,6 +1,5 @@
|
||||
# ProtoTorch Models
|
||||
|
||||
[](https://travis-ci.com/github/si-cim/prototorch_models)
|
||||
[](https://github.com/si-cim/prototorch_models/releases)
|
||||
[](https://pypi.org/project/prototorch_models/)
|
||||
[](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
|
||||
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.3.0"
|
||||
release = "1.0.0-a8"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
|
@@ -23,6 +23,13 @@ ProtoTorch Models Plugins
|
||||
|
||||
custom
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 3
|
||||
:caption: Proto Y Architecture
|
||||
|
||||
y-architecture
|
||||
|
||||
About
|
||||
-----------------------------------------
|
||||
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
|
||||
@@ -33,8 +40,10 @@ prototype-based Machine Learning algorithms using `PyTorch-Lightning
|
||||
Library
|
||||
-----------------------------------------
|
||||
Prototorch Models delivers many application ready models.
|
||||
These models have been published in the past and have been adapted to the Prototorch library.
|
||||
These models have been published in the past and have been adapted to the
|
||||
Prototorch library.
|
||||
|
||||
Customizable
|
||||
-----------------------------------------
|
||||
Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch.
|
||||
Prototorch Models also contains the building blocks to build own models with
|
||||
PyTorch-Lightning and Prototorch.
|
||||
|
@@ -71,7 +71,7 @@ Probabilistic Models
|
||||
Probabilistic variants assume, that the prototypes generate a probability distribution over the classes.
|
||||
For a test sample they return a distribution instead of a class assignment.
|
||||
|
||||
The following two algorihms were presented by :cite:t:`seo2003` .
|
||||
The following two algorithms were presented by :cite:t:`seo2003` .
|
||||
Every prototypes is a center of a gaussian distribution of its class, generating a mixture model.
|
||||
|
||||
.. autoclass:: prototorch.models.probabilistic.SLVQ
|
||||
@@ -80,7 +80,7 @@ Every prototypes is a center of a gaussian distribution of its class, generating
|
||||
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
||||
:members:
|
||||
|
||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation.
|
||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incorporate the winning rank into the prior probability calculation.
|
||||
And second use divergence as loss function.
|
||||
|
||||
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
||||
@@ -106,7 +106,7 @@ Visualization
|
||||
Visualization is very specific to its application.
|
||||
PrototorchModels delivers visualization for two dimensional data and image data.
|
||||
|
||||
The visulizations can be shown in a seperate window and inside a tensorboard.
|
||||
The visualizations can be shown in a separate window and inside a tensorboard.
|
||||
|
||||
.. automodule:: prototorch.models.vis
|
||||
:members:
|
||||
|
@@ -2,223 +2,252 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7ac5eff0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# A short tutorial for the `prototorch.models` plugin"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beb83780",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Introduction"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43b74278",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
|
||||
"\n",
|
||||
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
||||
"\n",
|
||||
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4e5d1fad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basics"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1244b66b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dcb88e8a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import prototorch as pt\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1adbe2f8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Building Models"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "96663ab1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "819ba756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.GLVQ(\n",
|
||||
" hparams=dict(distribution=[1, 1, 1]),\n",
|
||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b37e97c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(model)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2c86903",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
|
||||
"\n",
|
||||
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45806052",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Data"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9d62c4c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "504df02c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = pt.datasets.Iris(dims=[0, 2])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3b8e7756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(train_ds)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bce43afa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds.data.shape, train_ds.targets.shape"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26a83328",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "67b80fbe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1185f31",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b5a8963",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_batch, y_batch = next(iter(train_loader))\n",
|
||||
"print(f\"{x_batch=}, {y_batch=}\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dd492ee2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5176b055",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Training"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "46a7a506",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "279e75b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e496b492",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.fit(model, train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "497fbff6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### From data to a trained model - a very minimal example"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ab069c5d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
|
||||
@@ -230,49 +259,239 @@
|
||||
"\n",
|
||||
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
|
||||
"trainer.fit(model, train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "30c71a93",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced"
|
||||
],
|
||||
"metadata": {}
|
||||
"### Saving/Loading trained models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f74ed2c1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
||||
],
|
||||
"metadata": {}
|
||||
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3156658d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
|
||||
"trainer.save_checkpoint(ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1c34055",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bbbb08e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Visualizing decision boundaries in 2D"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "53ca52dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8373531f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Saving/Loading trained weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "937bc458",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f2035af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
|
||||
"torch.save(model.state_dict(), ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1206021a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.GLVQ(\n",
|
||||
" dict(distribution=(3, 2)),\n",
|
||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f2a4beb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "528d2fc2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.load(ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec817e6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a208eab7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f8de748f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "53a64063",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Warm-start a model with prototypes learned from another model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3177c277",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
|
||||
"model = pt.models.SiameseGMLVQ(\n",
|
||||
" dict(input_dim=2,\n",
|
||||
" latent_dim=2,\n",
|
||||
" distribution=(3, 2),\n",
|
||||
" proto_lr=0.0001,\n",
|
||||
" bb_lr=0.0001),\n",
|
||||
" optimizer=torch.optim.Adam,\n",
|
||||
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
|
||||
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
|
||||
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8baee9a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc203088",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f6a33a5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "946ce341",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import prototorch as pt\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch\n",
|
||||
"from torchvision import transforms\n",
|
||||
"from torchvision.datasets import MNIST"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision.utils import make_grid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "510d9bd4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import pyplot as plt"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea7c1228",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = MNIST(\n",
|
||||
" \"~/datasets\",\n",
|
||||
@@ -284,59 +503,87 @@
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b9eaf5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"s = int(0.05 * len(train_ds))\n",
|
||||
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8c32c9f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"init_ds"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "68a9a8b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.ImageGLVQ(\n",
|
||||
" dict(distribution=(10, 5)),\n",
|
||||
" dict(distribution=(10, 1)),\n",
|
||||
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=10))"
|
||||
],
|
||||
"id": "6f23df86",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
"source": [
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=5))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c23c7b2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "30780927",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protos, plabels = pt.components.LabeledComponents(\n",
|
||||
" distribution=(10, 5),\n",
|
||||
" components_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
|
||||
")()\n",
|
||||
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4fa69f92",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## FAQs"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa20f9ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
|
||||
"\n",
|
||||
@@ -351,11 +598,12 @@
|
||||
"```python\n",
|
||||
">>> model.prototype_labels\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ba8215bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### How do I make inferences/predictions/recall with my trained model?\n",
|
||||
"\n",
|
||||
@@ -370,13 +618,12 @@
|
||||
"```python\n",
|
||||
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -390,7 +637,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.4"
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
71
docs/source/y-architecture.rst
Normal file
71
docs/source/y-architecture.rst
Normal file
@@ -0,0 +1,71 @@
|
||||
.. Documentation of the updated Architecture.
|
||||
|
||||
Proto Y Architecture
|
||||
========================================
|
||||
|
||||
Overview
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture is a framework for abstract prototype learning methods.
|
||||
|
||||
It divides the problem into multiple steps:
|
||||
|
||||
* **Components** : Recalling the position and metadata of the components/prototypes.
|
||||
* **Backbone** : Apply a mapping function to data and prototypes.
|
||||
* **Comparison** : Calculate a dissimilarity based on the latent positions.
|
||||
* **Competition** : Calculate competition values based on the comparison and the metadata.
|
||||
* **Loss** : Calculate the loss based on the competition values
|
||||
* **Inference** : Predict the output based on the competition values.
|
||||
|
||||
Depending on the phase (Training or Testing) Loss or Inference is used.
|
||||
|
||||
Inheritance Structure
|
||||
****************************************
|
||||
|
||||
The Proto Y Architecture has a single base class that defines all steps and hooks
|
||||
of the architecture.
|
||||
|
||||
.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture
|
||||
|
||||
**Steps**
|
||||
|
||||
Components
|
||||
|
||||
.. automethod:: init_components
|
||||
.. automethod:: components
|
||||
|
||||
Backbone
|
||||
|
||||
.. automethod:: init_backbone
|
||||
.. automethod:: backbone
|
||||
|
||||
Comparison
|
||||
|
||||
.. automethod:: init_comparison
|
||||
.. automethod:: comparison
|
||||
|
||||
Competition
|
||||
|
||||
.. automethod:: init_competition
|
||||
.. automethod:: competition
|
||||
|
||||
Loss
|
||||
|
||||
.. automethod:: init_loss
|
||||
.. automethod:: loss
|
||||
|
||||
Inference
|
||||
|
||||
.. automethod:: init_inference
|
||||
.. automethod:: inference
|
||||
|
||||
**Hooks**
|
||||
|
||||
Torchmetric
|
||||
|
||||
.. automethod:: register_torchmetric
|
||||
|
||||
Hyperparameters
|
||||
****************************************
|
||||
Every model implemented with the Proto Y Architecture has a set of hyperparameters,
|
||||
which is stored in the ``HyperParameters`` attribute of the architecture.
|
@@ -1,53 +0,0 @@
|
||||
"""CBC example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 0, 3],
|
||||
margin=0.1,
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.CBC(
|
||||
hparams,
|
||||
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
|
||||
reasonings_iniitializer=pt.initializers.
|
||||
PurePositiveReasoningsInitializer(),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisCBC2D(data=train_ds,
|
||||
title="CBC Iris Example",
|
||||
resolution=100,
|
||||
axis_off=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,8 +0,0 @@
|
||||
# Examples using Lightning CLI
|
||||
|
||||
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
|
||||
|
||||
To use the example run
|
||||
```
|
||||
python gmlvq.py --config gmlvq.yaml
|
||||
```
|
@@ -1,19 +0,0 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import torch
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.abstract import PrototypeModel
|
||||
from prototorch.models.data import MNISTDataModule
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class ExperimentClass(ImageGMLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.zeros(28 * 28),
|
||||
**kwargs)
|
||||
|
||||
|
||||
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
|
@@ -1,11 +0,0 @@
|
||||
model:
|
||||
hparams:
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
distribution:
|
||||
num_classes: 10
|
||||
prototypes_per_class: 2
|
||||
proto_lr: 0.01
|
||||
bb_lr: 0.01
|
||||
data:
|
||||
batch_size: 32
|
@@ -1,81 +0,0 @@
|
||||
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
num_classes = 4
|
||||
num_features = 2
|
||||
num_clusters = 1
|
||||
train_ds = pt.datasets.Random(num_samples=500,
|
||||
num_classes=num_classes,
|
||||
num_features=num_features,
|
||||
num_clusters=num_clusters,
|
||||
separation=3.0,
|
||||
seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
prototypes_per_class = num_clusters * 5
|
||||
hparams = dict(
|
||||
distribution=(num_classes, prototypes_per_class),
|
||||
lr=0.2,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.CELVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01, # prune prototype if it wins less than 1%
|
||||
idle_epochs=20, # pruning too early may cause problems
|
||||
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
|
||||
frequency=1, # prune every epoch
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
es,
|
||||
],
|
||||
progress_bar_refresh_rate=0,
|
||||
terminate_on_nan=True,
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,55 +0,0 @@
|
||||
"""GLVQ example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution={
|
||||
"num_classes": 3,
|
||||
"per_class": 4
|
||||
},
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,76 +0,0 @@
|
||||
"""GLVQ example using the spiral dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
num_classes = 2
|
||||
prototypes_per_class = 10
|
||||
hparams = dict(
|
||||
distribution=(num_classes, prototypes_per_class),
|
||||
transfer_function="swish_beta",
|
||||
transfer_beta=10.0,
|
||||
proto_lr=0.1,
|
||||
bb_lr=0.1,
|
||||
input_dim=2,
|
||||
latent_dim=2,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(
|
||||
train_ds,
|
||||
show_last_only=False,
|
||||
block=False,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01,
|
||||
idle_epochs=10,
|
||||
prune_quota_per_epoch=5,
|
||||
frequency=5,
|
||||
replace=True,
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=1.0,
|
||||
patience=5,
|
||||
mode="min",
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
pruning,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
144
examples/gmlvq_iris.py
Normal file
144
examples/gmlvq_iris.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI, PCALinearTransformInitializer
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.architectures.base import Steps
|
||||
from prototorch.models.callbacks import (
|
||||
LogTorchmetricCallback,
|
||||
PlotLambdaMatrixToTensorboard,
|
||||
VisGMLVQ2D,
|
||||
)
|
||||
from prototorch.models.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
|
||||
def main():
|
||||
# ------------------------------------------------------------
|
||||
# DATA
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Dataset
|
||||
full_dataset = Iris()
|
||||
full_count = len(full_dataset)
|
||||
|
||||
train_count = int(full_count * 0.5)
|
||||
val_count = int(full_count * 0.4)
|
||||
test_count = int(full_count * 0.1)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = random_split(
|
||||
full_dataset, (train_count, val_count, test_count))
|
||||
|
||||
# Dataloader
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=False,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# HYPERPARAMETERS
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Select Initializer
|
||||
components_initializer = SMCI(full_dataset)
|
||||
|
||||
# Define Hyperparameters
|
||||
hyperparameters = GMLVQ.HyperParameters(
|
||||
lr=dict(components_layer=0.1, _omega=0),
|
||||
input_dim=4,
|
||||
distribution=dict(
|
||||
num_classes=3,
|
||||
per_class=1,
|
||||
),
|
||||
component_initializer=components_initializer,
|
||||
omega_initializer=PCALinearTransformInitializer,
|
||||
omega_initializer_kwargs=dict(
|
||||
data=train_dataset.dataset[train_dataset.indices][0]))
|
||||
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# TRAINING
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Controlling Callbacks
|
||||
recall = LogTorchmetricCallback(
|
||||
'training_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.TRAINING,
|
||||
)
|
||||
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'validation_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
accuracy = LogTorchmetricCallback(
|
||||
'validation_accuracy',
|
||||
torchmetrics.Accuracy,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
es = EarlyStopping(
|
||||
monitor=stopping_criterion.name,
|
||||
mode="max",
|
||||
patience=10,
|
||||
)
|
||||
|
||||
# Visualization Callback
|
||||
vis = VisGMLVQ2D(data=full_dataset)
|
||||
|
||||
# Define trainer
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
vis,
|
||||
recall,
|
||||
accuracy,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
],
|
||||
max_epochs=100,
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
trainer.test(model, test_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./y_arch.ckpt")
|
||||
|
||||
# Load saved model
|
||||
new_model = GMLVQ.load_from_checkpoint(
|
||||
checkpoint_path="./y_arch.ckpt",
|
||||
strict=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,101 +0,0 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
test_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
num_classes = 10
|
||||
prototypes_per_class = 10
|
||||
hparams = dict(
|
||||
input_dim=28 * 28,
|
||||
latent_dim=28 * 28,
|
||||
distribution=(num_classes, prototypes_per_class),
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.ImageGMLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisImgComp(
|
||||
data=train_ds,
|
||||
num_columns=10,
|
||||
show=False,
|
||||
tensorboard=True,
|
||||
random_data=100,
|
||||
add_embedding=True,
|
||||
embedding_data=200,
|
||||
flatten_data=False,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01,
|
||||
idle_epochs=1,
|
||||
prune_quota_per_epoch=10,
|
||||
frequency=1,
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=15,
|
||||
mode="min",
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
# es,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
weights_summary=None,
|
||||
# accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,53 +0,0 @@
|
||||
"""Growing Neural Gas example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Prepare the data
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
num_prototypes=5,
|
||||
input_dim=2,
|
||||
lr=0.1,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GrowingNeuralGas(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.ZCI(2),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisNG2D(data=train_loader)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=100,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,57 +0,0 @@
|
||||
"""k-NN example using the Iris dataset from scikit-learn."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(k=5)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.KNN(hparams, data=train_ds)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(
|
||||
data=(x_train, y_train),
|
||||
resolution=200,
|
||||
block=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=1,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
# This is only for visualization. k-NN has no training phase.
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Recall
|
||||
y_pred = model.predict(torch.tensor(x_train))
|
||||
print(y_pred)
|
@@ -1,103 +0,0 @@
|
||||
"""Kohonen Self Organizing Map."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.utils.colors import hex_to_rgb
|
||||
|
||||
|
||||
class Vis2DColorSOM(pl.Callback):
|
||||
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
|
||||
super().__init__()
|
||||
self.title = title
|
||||
self.fig = plt.figure(self.title)
|
||||
self.data = data
|
||||
self.pause_time = pause_time
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
h, w = pl_module._grid.shape[:2]
|
||||
protos = pl_module.prototypes.view(h, w, 3)
|
||||
ax.imshow(protos)
|
||||
ax.axis("off")
|
||||
|
||||
# Overlay color names
|
||||
d = pl_module.compute_distances(self.data)
|
||||
wp = pl_module.predict_from_distances(d)
|
||||
for i, iloc in enumerate(wp):
|
||||
plt.text(iloc[1],
|
||||
iloc[0],
|
||||
cnames[i],
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox=dict(facecolor="white", alpha=0.5, lw=0))
|
||||
|
||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||
plt.pause(self.pause_time)
|
||||
else:
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Prepare the data
|
||||
hex_colors = [
|
||||
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
|
||||
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
||||
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
|
||||
]
|
||||
cnames = [
|
||||
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
||||
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
|
||||
"lightgrey", "olive", "purple", "orange"
|
||||
]
|
||||
colors = list(hex_to_rgb(hex_colors))
|
||||
data = torch.Tensor(colors) / 255.0
|
||||
train_ds = torch.utils.data.TensorDataset(data)
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
shape=(18, 32),
|
||||
alpha=1.0,
|
||||
sigma=16,
|
||||
lr=0.1,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.KohonenSOM(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.RNCI(3),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 3)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = Vis2DColorSOM(data=data)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=500,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,68 +0,0 @@
|
||||
"""Localized-GMLVQ example using the Moons dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=256,
|
||||
shuffle=True)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 3],
|
||||
input_dim=2,
|
||||
latent_dim=2,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.LGMLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_acc",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="max",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,90 +0,0 @@
|
||||
"""LVQMLN example using all four dimensions of the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.latent_size = latent_size
|
||||
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
||||
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.activation(self.dense1(x))
|
||||
out = self.activation(self.dense2(x))
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[3, 4, 5],
|
||||
proto_lr=0.001,
|
||||
bb_lr=0.001,
|
||||
)
|
||||
|
||||
# Initialize the backbone
|
||||
backbone = Backbone()
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.LVQMLN(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SSCI(
|
||||
train_ds,
|
||||
transform=backbone,
|
||||
),
|
||||
backbone=backbone,
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(
|
||||
data=train_ds,
|
||||
map_protos=False,
|
||||
border=0.1,
|
||||
resolution=500,
|
||||
axis_off=True,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01,
|
||||
idle_epochs=20,
|
||||
prune_quota_per_epoch=2,
|
||||
frequency=10,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,52 +0,0 @@
|
||||
"""Median-LVQ example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_ds,
|
||||
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.MedianLVQ(
|
||||
hparams=dict(distribution=(3, 2), lr=0.01),
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_acc",
|
||||
min_delta=0.01,
|
||||
patience=5,
|
||||
mode="max",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis, es],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,62 +0,0 @@
|
||||
"""Neural Gas example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare and pre-process the dataset
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
|
||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
num_prototypes=30,
|
||||
input_dim=2,
|
||||
lr=0.03,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.NeuralGas(
|
||||
hparams,
|
||||
prototypes_initializer=pt.core.ZCI(2),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisNG2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,61 +0,0 @@
|
||||
"""RSLVQ example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=42)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[2, 2, 3],
|
||||
proto_lr=0.05,
|
||||
lambd=0.1,
|
||||
variance=1.0,
|
||||
input_dim=2,
|
||||
latent_dim=2,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.RSLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
terminate_on_nan=True,
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,72 +0,0 @@
|
||||
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.latent_size = latent_size
|
||||
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
||||
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.activation(self.dense1(x))
|
||||
out = self.activation(self.dense2(x))
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the backbone
|
||||
backbone = Backbone()
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
backbone=backbone,
|
||||
both_path_gradients=False,
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,103 +0,0 @@
|
||||
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare the data
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Initialize the gng
|
||||
gng = pt.models.GrowingNeuralGas(
|
||||
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
|
||||
prototypes_initializer=pt.initializers.ZCI(2),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="loss",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="min",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer for GNG
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=100,
|
||||
callbacks=[es],
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(gng, train_loader)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[],
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Warm-start prototypes
|
||||
knn = pt.models.KNN(dict(k=1), data=train_ds)
|
||||
prototypes = gng.prototypes
|
||||
plabels = knn.predict(prototypes)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.LCI(prototypes),
|
||||
labels_initializer=pt.initializers.LLI(plabels),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.02,
|
||||
idle_epochs=2,
|
||||
prune_quota_per_epoch=5,
|
||||
frequency=1,
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=10,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -1,24 +1,25 @@
|
||||
"""`models` plugin for the `prototorch` package."""
|
||||
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
from .architectures.base import BaseYArchitecture
|
||||
from .architectures.comparison import (
|
||||
OmegaComparisonMixin,
|
||||
SimpleComparisonMixin,
|
||||
)
|
||||
from .architectures.competition import WTACompetitionMixin
|
||||
from .architectures.components import SupervisedArchitecture
|
||||
from .architectures.loss import GLVQLossMixin
|
||||
from .architectures.optimization import (
|
||||
MultipleLearningRateMixin,
|
||||
SingleLearningRateMixin,
|
||||
)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||
from .vis import *
|
||||
|
||||
__version__ = "0.3.0"
|
||||
__all__ = [
|
||||
'BaseYArchitecture',
|
||||
"OmegaComparisonMixin",
|
||||
"SimpleComparisonMixin",
|
||||
"SingleLearningRateMixin",
|
||||
"MultipleLearningRateMixin",
|
||||
"SupervisedArchitecture",
|
||||
"WTACompetitionMixin",
|
||||
"GLVQLossMixin",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0-a8"
|
||||
|
@@ -1,192 +0,0 @@
|
||||
"""Abstract classes to be inherited by prototorch models."""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from ..core.competitions import WTAC
|
||||
from ..core.components import Components, LabeledComponents
|
||||
from ..core.distances import euclidean_distance
|
||||
from ..core.initializers import LabelsInitializer
|
||||
from ..core.pooling import stratified_min_pooling
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
|
||||
|
||||
class ProtoTorchBolt(pl.LightningModule):
|
||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
# Hyperparameters
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("lr", 0.01)
|
||||
|
||||
# Default config
|
||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||
if self.lr_scheduler is not None:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return [optimizer], [sch]
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
def reconfigure_optimizers(self):
|
||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||
|
||||
def __repr__(self):
|
||||
surep = super().__repr__()
|
||||
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
|
||||
wrapped = f"ProtoTorch Bolt(\n{indented})"
|
||||
return wrapped
|
||||
|
||||
|
||||
class PrototypeModel(ProtoTorchBolt):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||
self.distance_layer = LambdaLayer(distance_fn)
|
||||
|
||||
@property
|
||||
def num_prototypes(self):
|
||||
return len(self.proto_layer.components)
|
||||
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.proto_layer.components.detach().cpu()
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
"""Only an alias for the prototypes."""
|
||||
return self.prototypes
|
||||
|
||||
def add_prototypes(self, *args, **kwargs):
|
||||
self.proto_layer.add_components(*args, **kwargs)
|
||||
self.reconfigure_optimizers()
|
||||
|
||||
def remove_prototypes(self, indices):
|
||||
self.proto_layer.remove_components(indices)
|
||||
self.reconfigure_optimizers()
|
||||
|
||||
|
||||
class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Layers
|
||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||
if prototypes_initializer is not None:
|
||||
self.proto_layer = Components(
|
||||
self.hparams.num_prototypes,
|
||||
initializer=prototypes_initializer,
|
||||
)
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos = self.proto_layer().type_as(x)
|
||||
distances = self.distance_layer(x, protos)
|
||||
return distances
|
||||
|
||||
def forward(self, x):
|
||||
distances = self.compute_distances(x)
|
||||
return distances
|
||||
|
||||
|
||||
class SupervisedPrototypeModel(PrototypeModel):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Layers
|
||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||
labels_initializer = kwargs.get("labels_initializer",
|
||||
LabelsInitializer())
|
||||
if prototypes_initializer is not None:
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=self.hparams.distribution,
|
||||
components_initializer=prototypes_initializer,
|
||||
labels_initializer=labels_initializer,
|
||||
)
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.labels.detach().cpu()
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self.proto_layer.num_classes
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos)
|
||||
return distances
|
||||
|
||||
def forward(self, x):
|
||||
distances = self.compute_distances(x)
|
||||
_, plabels = self.proto_layer()
|
||||
winning = stratified_min_pooling(distances, plabels)
|
||||
y_pred = torch.nn.functional.softmin(winning)
|
||||
return y_pred
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
with torch.no_grad():
|
||||
_, plabels = self.proto_layer()
|
||||
y_pred = self.competition_layer(distances, plabels)
|
||||
return y_pred
|
||||
|
||||
def predict(self, x):
|
||||
with torch.no_grad():
|
||||
distances = self.compute_distances(x)
|
||||
y_pred = self.predict_from_distances(distances)
|
||||
return y_pred
|
||||
|
||||
def log_acc(self, distances, targets, tag):
|
||||
preds = self.predict_from_distances(distances)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
|
||||
self.log(tag,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
pass
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||
from torchvision.utils import make_grid
|
||||
grid = make_grid(self.components, nrow=num_columns)
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
290
prototorch/models/architectures/base.py
Normal file
290
prototorch/models/architectures/base.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Proto Y Architecture
|
||||
|
||||
Network architecture for Component based Learning.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
|
||||
class Steps(enumerate):
|
||||
TRAINING = "training"
|
||||
VALIDATION = "validation"
|
||||
TEST = "test"
|
||||
PREDICT = "predict"
|
||||
|
||||
|
||||
class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
@dataclass
|
||||
class HyperParameters:
|
||||
"""
|
||||
Add all hyperparameters in the inherited class.
|
||||
"""
|
||||
...
|
||||
|
||||
# Fields
|
||||
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
registered_metric_callbacks: dict[str, dict[type[Metric],
|
||||
set[Callable]]] = {
|
||||
Steps.TRAINING: {},
|
||||
Steps.VALIDATION: {},
|
||||
Steps.TEST: {},
|
||||
}
|
||||
|
||||
# Type Hints for Necessary Fields
|
||||
components_layer: torch.nn.Module
|
||||
|
||||
def __init__(self, hparams) -> None:
|
||||
if isinstance(hparams, dict):
|
||||
self.save_hyperparameters(hparams)
|
||||
# TODO: => Move into Component Child
|
||||
del hparams["initialized_proto_shape"]
|
||||
hparams = self.HyperParameters(**hparams)
|
||||
else:
|
||||
hparams_dict = asdict(hparams)
|
||||
hparams_dict["component_initializer"] = None
|
||||
self.save_hyperparameters(hparams_dict, )
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Common Steps
|
||||
self.init_components(hparams)
|
||||
self.init_backbone(hparams)
|
||||
self.init_comparison(hparams)
|
||||
self.init_competition(hparams)
|
||||
|
||||
# Train Steps
|
||||
self.init_loss(hparams)
|
||||
|
||||
# Inference Steps
|
||||
self.init_inference(hparams)
|
||||
|
||||
# external API
|
||||
def get_competition(self, batch, components):
|
||||
'''
|
||||
Returns the output of the competition layer.
|
||||
'''
|
||||
latent_batch, latent_components = self.backbone(batch, components)
|
||||
# TODO: => Latent Hook
|
||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||
# TODO: => Comparison Hook
|
||||
return comparison_tensor
|
||||
|
||||
def forward(self, batch):
|
||||
'''
|
||||
Returns the prediction.
|
||||
'''
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = (batch, None)
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competition(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.inference(comparison_tensor, components)
|
||||
|
||||
def predict(self, batch):
|
||||
"""
|
||||
Alias for forward
|
||||
"""
|
||||
return self.forward(batch)
|
||||
|
||||
def forward_comparison(self, batch):
|
||||
'''
|
||||
Returns the Output of the comparison layer.
|
||||
'''
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = (batch, None)
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
return self.get_competition(batch, components)
|
||||
|
||||
def loss_forward(self, batch):
|
||||
'''
|
||||
Returns the output of the loss layer.
|
||||
'''
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competition(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.loss(comparison_tensor, batch, components)
|
||||
|
||||
# Empty Initialization
|
||||
def init_components(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the components step.
|
||||
"""
|
||||
|
||||
def init_backbone(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the backbone step.
|
||||
"""
|
||||
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the comparison step.
|
||||
"""
|
||||
|
||||
def init_competition(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the competition step.
|
||||
"""
|
||||
|
||||
def init_loss(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the loss step.
|
||||
"""
|
||||
|
||||
def init_inference(self, hparams: HyperParameters) -> None:
|
||||
"""
|
||||
All initialization necessary for the inference step.
|
||||
"""
|
||||
|
||||
# Empty Steps
|
||||
def components(self):
|
||||
"""
|
||||
This step has no input.
|
||||
|
||||
It returns the components.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The components step has no reasonable default.")
|
||||
|
||||
def backbone(self, batch, components):
|
||||
"""
|
||||
The backbone step receives the data batch and the components.
|
||||
It can transform both by an arbitrary function.
|
||||
|
||||
It returns the transformed batch and components,
|
||||
each of the same length as the original input.
|
||||
"""
|
||||
return batch, components
|
||||
|
||||
def comparison(self, batch, components):
|
||||
"""
|
||||
Takes a batch of size N and the component set of size M.
|
||||
|
||||
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The comparison step has no reasonable default.")
|
||||
|
||||
def competition(self, comparison_measures, components):
|
||||
"""
|
||||
Takes the tensor of comparison measures.
|
||||
|
||||
Assigns a competition vector to each class.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The competition step has no reasonable default.")
|
||||
|
||||
def loss(self, comparison_measures, batch, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Calculates a single loss value
|
||||
"""
|
||||
raise NotImplementedError("The loss step has no reasonable default.")
|
||||
|
||||
def inference(self, comparison_measures, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Returns the inferred vector.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The inference step has no reasonable default.")
|
||||
|
||||
# Y Architecture Hooks
|
||||
|
||||
# internal API, called by models and callbacks
|
||||
def register_torchmetric(
|
||||
self,
|
||||
name: Callable,
|
||||
metric: type[Metric],
|
||||
step: str = Steps.TRAINING,
|
||||
**metric_kwargs,
|
||||
):
|
||||
'''
|
||||
Register a callback for evaluating a torchmetric.
|
||||
'''
|
||||
if step == Steps.PREDICT:
|
||||
raise ValueError("Prediction metrics are not supported.")
|
||||
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[step][metric] = metric(**metric_kwargs)
|
||||
self.registered_metric_callbacks[step][metric] = {name}
|
||||
else:
|
||||
self.registered_metric_callbacks[step][metric].add(name)
|
||||
|
||||
def update_metrics_step(self, batch, step):
|
||||
# Prediction Metrics
|
||||
preds = self(batch)
|
||||
|
||||
_, y = batch
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
instance(y, preds.reshape(y.shape))
|
||||
|
||||
def update_metrics_epoch(self, step):
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
value = instance.compute()
|
||||
|
||||
for callback in self.registered_metric_callbacks[step][metric]:
|
||||
callback(value, self)
|
||||
|
||||
instance.reset()
|
||||
|
||||
# Lightning steps
|
||||
# -------------------------------------------------------------------------
|
||||
# >>>> Training
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
self.update_metrics_step(batch, Steps.TRAINING)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def training_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.TRAINING)
|
||||
|
||||
# >>>> Validation
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.VALIDATION)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.VALIDATION)
|
||||
|
||||
# >>>> Test
|
||||
def test_step(self, batch, batch_idx):
|
||||
self.update_metrics_step(batch, Steps.TEST)
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_epoch_end(self, outputs) -> None:
|
||||
self.update_metrics_epoch(Steps.TEST)
|
||||
|
||||
# >>>> Prediction
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
return self.predict(batch)
|
||||
|
||||
# Check points
|
||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||
# Compatible with Lightning
|
||||
checkpoint["hyper_parameters"] = {
|
||||
'hparams': checkpoint["hyper_parameters"]
|
||||
}
|
||||
return super().on_save_checkpoint(checkpoint)
|
148
prototorch/models/architectures/comparison.py
Normal file
148
prototorch/models/architectures/comparison.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.core.initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeLinearTransformInitializer,
|
||||
)
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SimpleComparisonMixin(BaseYArchitecture):
|
||||
"""
|
||||
Simple Comparison
|
||||
|
||||
A comparison layer that only uses the positions of the components
|
||||
and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = euclidean_distance
|
||||
comparison_args: dict = field(default_factory=dict)
|
||||
|
||||
comparison_parameters: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters):
|
||||
self.comparison_layer = LambdaLayer(
|
||||
fn=hparams.comparison_fn,
|
||||
**hparams.comparison_args,
|
||||
)
|
||||
|
||||
self.comparison_kwargs: dict[str, Tensor] = {}
|
||||
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
batch_tensor, _ = batch
|
||||
|
||||
comp_tensor = comp_tensor.unsqueeze(1)
|
||||
|
||||
distances = self.comparison_layer(
|
||||
batch_tensor,
|
||||
comp_tensor,
|
||||
**self.comparison_kwargs,
|
||||
)
|
||||
|
||||
return distances
|
||||
|
||||
|
||||
class OmegaComparisonMixin(SimpleComparisonMixin):
|
||||
"""
|
||||
Omega Comparison
|
||||
|
||||
A comparison layer that uses the positions of the components
|
||||
and the batch for dissimilarity computation.
|
||||
"""
|
||||
|
||||
_omega: torch.Tensor
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(SimpleComparisonMixin.HyperParameters):
|
||||
"""
|
||||
input_dim: Necessary Field: The dimensionality of the input.
|
||||
latent_dim:
|
||||
The dimensionality of the latent space. Default: 2.
|
||||
omega_initializer:
|
||||
The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
|
||||
"""
|
||||
input_dim: int | None = None
|
||||
latent_dim: int = 2
|
||||
omega_initializer: type[
|
||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||
omega_initializer_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||
super().init_comparison(hparams)
|
||||
|
||||
# Initialize the omega matrix
|
||||
if hparams.input_dim is None:
|
||||
raise ValueError("input_dim must be specified.")
|
||||
else:
|
||||
omega = hparams.omega_initializer(
|
||||
**hparams.omega_initializer_kwargs).generate(
|
||||
hparams.input_dim,
|
||||
hparams.latent_dim,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
self.comparison_kwargs = dict(omega=self._omega)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
'''
|
||||
Omega Matrix. Mapping applied to data and prototypes.
|
||||
'''
|
||||
return self._omega.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
'''
|
||||
Lambda Matrix.
|
||||
'''
|
||||
omega = self._omega.detach()
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
'''
|
||||
Relevance Profile. Main Diagonal of the Lambda Matrix.
|
||||
'''
|
||||
return self.lambda_matrix.diag().abs()
|
||||
|
||||
@property
|
||||
def classification_influence_profile(self):
|
||||
'''
|
||||
Classification Influence Profile. Influence of each dimension.
|
||||
'''
|
||||
lam = self.lambda_matrix
|
||||
return lam.abs().sum(0)
|
||||
|
||||
@property
|
||||
def parameter_omega(self):
|
||||
return self._omega
|
||||
|
||||
@parameter_omega.setter
|
||||
def parameter_omega(self, new_omega):
|
||||
with torch.no_grad():
|
||||
self._omega.data.copy_(new_omega)
|
29
prototorch/models/architectures/competition.py
Normal file
29
prototorch/models/architectures/competition.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.core.competitions import WTAC
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class WTACompetitionMixin(BaseYArchitecture):
|
||||
"""
|
||||
Winner Take All Competition
|
||||
|
||||
A competition layer that uses the winner-take-all strategy.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
No hyperparameters.
|
||||
"""
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_inference(self, hparams: HyperParameters):
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
def inference(self, comparison_measures, components):
|
||||
comp_labels = components[1]
|
||||
return self.competition_layer(comparison_measures, comp_labels)
|
64
prototorch/models/architectures/components.py
Normal file
64
prototorch/models/architectures/components.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.core.components import LabeledComponents
|
||||
from prototorch.core.initializers import (
|
||||
AbstractComponentsInitializer,
|
||||
LabelsInitializer,
|
||||
ZerosCompInitializer,
|
||||
)
|
||||
from prototorch.models import BaseYArchitecture
|
||||
|
||||
|
||||
class SupervisedArchitecture(BaseYArchitecture):
|
||||
"""
|
||||
Supervised Architecture
|
||||
|
||||
An architecture that uses labeled Components as component Layer.
|
||||
"""
|
||||
components_layer: LabeledComponents
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters:
|
||||
"""
|
||||
distribution: A valid prototype distribution. No default possible.
|
||||
components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
|
||||
"""
|
||||
distribution: "dict[str, int]"
|
||||
component_initializer: AbstractComponentsInitializer
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_components(self, hparams: HyperParameters):
|
||||
if hparams.component_initializer is not None:
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=hparams.component_initializer,
|
||||
labels_initializer=LabelsInitializer(),
|
||||
)
|
||||
proto_shape = self.components_layer.components.shape[1:]
|
||||
self.hparams["initialized_proto_shape"] = proto_shape
|
||||
else:
|
||||
# when restoring a checkpointed model
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=ZerosCompInitializer(
|
||||
self.hparams["initialized_proto_shape"]),
|
||||
)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@property
|
||||
def prototypes(self):
|
||||
"""
|
||||
Returns the position of the prototypes.
|
||||
"""
|
||||
return self.components_layer.components.detach().cpu()
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
"""
|
||||
Returns the labels of the prototypes.
|
||||
"""
|
||||
return self.components_layer.labels.detach().cpu()
|
42
prototorch/models/architectures/loss.py
Normal file
42
prototorch/models/architectures/loss.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from prototorch.core.losses import GLVQLoss
|
||||
from prototorch.models.architectures.base import BaseYArchitecture
|
||||
|
||||
|
||||
class GLVQLossMixin(BaseYArchitecture):
|
||||
"""
|
||||
GLVQ Loss
|
||||
|
||||
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
margin: The margin of the GLVQ loss. Default: 0.0.
|
||||
transfer_fn: Transfer function to use. Default: sigmoid_beta.
|
||||
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
|
||||
"""
|
||||
margin: float = 0.0
|
||||
|
||||
transfer_fn: str = "sigmoid_beta"
|
||||
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_loss(self, hparams: HyperParameters):
|
||||
self.loss_layer = GLVQLoss(
|
||||
margin=hparams.margin,
|
||||
transfer_fn=hparams.transfer_fn,
|
||||
**hparams.transfer_args,
|
||||
)
|
||||
|
||||
def loss(self, comparison_measures, batch, components):
|
||||
target = batch[1]
|
||||
comp_labels = components[1]
|
||||
loss = self.loss_layer(comparison_measures, target, comp_labels)
|
||||
self.log('loss', loss)
|
||||
return loss
|
73
prototorch/models/architectures/optimization.py
Normal file
73
prototorch/models/architectures/optimization.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from prototorch.models import BaseYArchitecture
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class SingleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Single Learning Rate
|
||||
|
||||
All parameters are updated with a single learning rate.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: float = 0.1
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
return self.hparams.optimizer(self.parameters(),
|
||||
lr=self.hparams.lr) # type: ignore
|
||||
|
||||
|
||||
class MultipleLearningRateMixin(BaseYArchitecture):
|
||||
"""
|
||||
Multiple Learning Rates
|
||||
|
||||
Define Different Learning Rates for different parameters.
|
||||
"""
|
||||
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(BaseYArchitecture.HyperParameters):
|
||||
"""
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: dict = field(default_factory=dict)
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Hooks
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def configure_optimizers(self):
|
||||
optimizers = []
|
||||
for name, lr in self.hparams.lr.items():
|
||||
if not hasattr(self, name):
|
||||
raise ValueError(f"{name} is not a parameter of {self}")
|
||||
else:
|
||||
model_part = getattr(self, name)
|
||||
if isinstance(model_part, Parameter):
|
||||
optimizers.append(
|
||||
self.hparams.optimizer(
|
||||
[model_part],
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
elif hasattr(model_part, "parameters"):
|
||||
optimizers.append(
|
||||
self.hparams.optimizer(
|
||||
model_part.parameters(),
|
||||
lr=lr, # type: ignore
|
||||
))
|
||||
return optimizers
|
@@ -1,137 +1,307 @@
|
||||
"""Lightning Callbacks."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.models.architectures.base import BaseYArchitecture, Steps
|
||||
from prototorch.models.architectures.comparison import OmegaComparisonMixin
|
||||
from prototorch.models.library.gmlvq import GMLVQ
|
||||
from prototorch.models.vis import Vis2DAbstract
|
||||
from prototorch.utils.utils import mesh2d
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
from ..core.components import Components
|
||||
from ..core.initializers import LiteralCompInitializer
|
||||
from .extras import ConnectionTopology
|
||||
DIVERGING_COLOR_MAPS = [
|
||||
'PiYG',
|
||||
'PRGn',
|
||||
'BrBG',
|
||||
'PuOr',
|
||||
'RdGy',
|
||||
'RdBu',
|
||||
'RdYlBu',
|
||||
'RdYlGn',
|
||||
'Spectral',
|
||||
'coolwarm',
|
||||
'bwr',
|
||||
'seismic',
|
||||
]
|
||||
|
||||
|
||||
class PruneLoserPrototypes(pl.Callback):
|
||||
def __init__(self,
|
||||
threshold=0.01,
|
||||
idle_epochs=10,
|
||||
prune_quota_per_epoch=-1,
|
||||
frequency=1,
|
||||
replace=False,
|
||||
prototypes_initializer=None,
|
||||
verbose=False):
|
||||
self.threshold = threshold # minimum win ratio
|
||||
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
||||
self.prune_quota_per_epoch = prune_quota_per_epoch
|
||||
self.frequency = frequency
|
||||
self.replace = replace
|
||||
self.verbose = verbose
|
||||
self.prototypes_initializer = prototypes_initializer
|
||||
class LogTorchmetricCallback(pl.Callback):
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||
return None
|
||||
if (trainer.current_epoch + 1) % self.frequency:
|
||||
return None
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
metric: Type[torchmetrics.Metric],
|
||||
step: str = Steps.TRAINING,
|
||||
on_epoch=True,
|
||||
**metric_kwargs,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.metric = metric
|
||||
self.metric_kwargs = metric_kwargs
|
||||
self.step = step
|
||||
self.on_epoch = on_epoch
|
||||
|
||||
ratios = pl_module.prototype_win_ratios.mean(dim=0)
|
||||
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
|
||||
to_prune = to_prune.tolist()
|
||||
prune_labels = pl_module.prototype_labels[to_prune]
|
||||
if self.prune_quota_per_epoch > 0:
|
||||
to_prune = to_prune[:self.prune_quota_per_epoch]
|
||||
prune_labels = prune_labels[:self.prune_quota_per_epoch]
|
||||
def setup(
|
||||
self,
|
||||
trainer: pl.Trainer,
|
||||
pl_module: BaseYArchitecture,
|
||||
stage: Optional[str] = None,
|
||||
) -> None:
|
||||
pl_module.register_torchmetric(
|
||||
self,
|
||||
self.metric,
|
||||
step=self.step,
|
||||
**self.metric_kwargs,
|
||||
)
|
||||
|
||||
if len(to_prune) > 0:
|
||||
if self.verbose:
|
||||
print(f"\nPrototype win ratios: {ratios}")
|
||||
print(f"Pruning prototypes at: {to_prune}")
|
||||
print(f"Corresponding labels are: {prune_labels.tolist()}")
|
||||
cur_num_protos = pl_module.num_prototypes
|
||||
pl_module.remove_prototypes(indices=to_prune)
|
||||
if self.replace:
|
||||
labels, counts = torch.unique(prune_labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
distribution = dict(zip(labels.tolist(), counts.tolist()))
|
||||
if self.verbose:
|
||||
print(f"Re-adding pruned prototypes...")
|
||||
print(f"distribution={distribution}")
|
||||
pl_module.add_prototypes(
|
||||
distribution=distribution,
|
||||
components_initializer=self.prototypes_initializer)
|
||||
new_num_protos = pl_module.num_prototypes
|
||||
if self.verbose:
|
||||
print(f"`num_prototypes` changed from {cur_num_protos} "
|
||||
f"to {new_num_protos}.")
|
||||
return True
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
pl_module.log(
|
||||
self.name,
|
||||
value,
|
||||
on_epoch=self.on_epoch,
|
||||
on_step=(not self.on_epoch),
|
||||
)
|
||||
|
||||
|
||||
class PrototypeConvergence(pl.Callback):
|
||||
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
|
||||
self.min_delta = min_delta
|
||||
self.idle_epochs = idle_epochs # epochs to wait
|
||||
self.verbose = verbose
|
||||
class LogConfusionMatrix(LogTorchmetricCallback):
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||
return None
|
||||
if self.verbose:
|
||||
print("Stopping...")
|
||||
# TODO
|
||||
return True
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
name="confusion",
|
||||
on='prediction',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
torchmetrics.ConfusionMatrix,
|
||||
on=on,
|
||||
num_classes=num_classes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(value.detach().cpu().numpy())
|
||||
|
||||
# Show all ticks and label them with the respective list entries
|
||||
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
|
||||
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
|
||||
|
||||
# Rotate the tick labels and set their alignment.
|
||||
plt.setp(
|
||||
ax.get_xticklabels(),
|
||||
rotation=45,
|
||||
ha="right",
|
||||
rotation_mode="anchor",
|
||||
)
|
||||
|
||||
# Loop over data dimensions and create text annotations.
|
||||
for i in range(len(value)):
|
||||
for j in range(len(value)):
|
||||
text = ax.text(
|
||||
j,
|
||||
i,
|
||||
value[i, j].item(),
|
||||
ha="center",
|
||||
va="center",
|
||||
color="w",
|
||||
)
|
||||
|
||||
ax.set_title(self.name)
|
||||
fig.tight_layout()
|
||||
|
||||
pl_module.logger.experiment.add_figure(
|
||||
tag=self.name,
|
||||
figure=fig,
|
||||
close=True,
|
||||
global_step=pl_module.global_step,
|
||||
)
|
||||
|
||||
|
||||
class GNGCallback(pl.Callback):
|
||||
"""GNG Callback.
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
|
||||
Applies growing algorithm based on accumulated error and topology.
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax()
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
if x_train is not None:
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(
|
||||
np.vstack([x_train, protos]),
|
||||
self.border,
|
||||
self.resolution,
|
||||
)
|
||||
else:
|
||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||
_components = pl_module.components_layer.components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||
|
||||
"""
|
||||
def __init__(self, reduction=0.1, freq=10):
|
||||
self.reduction = reduction
|
||||
self.freq = freq
|
||||
class VisGMLVQ2D(Vis2DAbstract):
|
||||
|
||||
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) % self.freq == 0:
|
||||
# Get information
|
||||
errors = pl_module.errors
|
||||
topology: ConnectionTopology = pl_module.topology_layer
|
||||
components: Components = pl_module.proto_layer.components
|
||||
def __init__(self, *args, ev_proj=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ev_proj = ev_proj
|
||||
|
||||
# Insertion point
|
||||
worst = torch.argmax(errors)
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
device = pl_module.device
|
||||
omega = pl_module._omega.detach()
|
||||
lam = omega @ omega.T
|
||||
u, _, _ = torch.pca_lowrank(lam, q=2)
|
||||
with torch.no_grad():
|
||||
x_train = torch.Tensor(x_train).to(device)
|
||||
x_train = x_train @ u
|
||||
x_train = x_train.cpu().detach()
|
||||
if self.show_protos:
|
||||
with torch.no_grad():
|
||||
protos = torch.Tensor(protos).to(device)
|
||||
protos = protos @ u
|
||||
protos = protos.cpu().detach()
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
|
||||
neighbors = topology.get_neighbors(worst)[0]
|
||||
|
||||
if len(neighbors) == 0:
|
||||
logging.log(level=20, msg="No neighbor-pairs found!")
|
||||
return
|
||||
class PlotLambdaMatrixToTensorboard(pl.Callback):
|
||||
|
||||
neighbors_errors = errors[neighbors]
|
||||
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
|
||||
def __init__(self, cmap='seismic') -> None:
|
||||
super().__init__()
|
||||
self.cmap = cmap
|
||||
|
||||
# New Prototype
|
||||
new_component = 0.5 * (components[worst] +
|
||||
components[worst_neighbor])
|
||||
if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
|
||||
warnings.warn(
|
||||
f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
|
||||
)
|
||||
|
||||
# Add component
|
||||
pl_module.proto_layer.add_components(
|
||||
None,
|
||||
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
|
||||
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||
self.plot_lambda(trainer, pl_module)
|
||||
|
||||
# Adjust Topology
|
||||
topology.add_prototype()
|
||||
topology.add_connection(worst, -1)
|
||||
topology.add_connection(worst_neighbor, -1)
|
||||
topology.remove_connection(worst, worst_neighbor)
|
||||
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||
self.plot_lambda(trainer, pl_module)
|
||||
|
||||
# New errors
|
||||
worst_error = errors[worst].unsqueeze(0)
|
||||
pl_module.errors = torch.cat([pl_module.errors, worst_error])
|
||||
pl_module.errors[worst] = errors[worst] * self.reduction
|
||||
pl_module.errors[
|
||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||
def plot_lambda(self, trainer, pl_module: GMLVQ):
|
||||
|
||||
trainer.accelerator.setup_optimizers(trainer)
|
||||
self.fig, self.ax = plt.subplots(1, 1)
|
||||
|
||||
# plot lambda matrix
|
||||
l_matrix = pl_module.lambda_matrix
|
||||
|
||||
# normalize lambda matrix
|
||||
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
|
||||
|
||||
# plot lambda matrix
|
||||
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
|
||||
|
||||
self.fig.colorbar(self.ax.images[-1])
|
||||
|
||||
# add title
|
||||
self.ax.set_title('Lambda Matrix')
|
||||
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
"lambda_matrix",
|
||||
self.fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class Profiles(Enum):
|
||||
'''
|
||||
Available Profiles
|
||||
'''
|
||||
RELEVANCE = 'relevance'
|
||||
INFLUENCE = 'influence'
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class PlotMatrixProfiles(pl.Callback):
|
||||
|
||||
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
|
||||
super().__init__()
|
||||
self.cmap = cmap
|
||||
self.profile = profile
|
||||
|
||||
def on_train_start(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot initial profile.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
|
||||
'''
|
||||
Plot after every epoch.
|
||||
'''
|
||||
self._plot_profile(trainer, pl_module)
|
||||
|
||||
def _plot_profile(self, trainer, pl_module: GMLVQ):
|
||||
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
|
||||
# plot lambda matrix
|
||||
l_matrix = torch.abs(pl_module.lambda_matrix)
|
||||
|
||||
if self.profile == Profiles.RELEVANCE:
|
||||
profile_value = l_matrix.diag()
|
||||
elif self.profile == Profiles.INFLUENCE:
|
||||
profile_value = l_matrix.sum(0)
|
||||
|
||||
# plot lambda matrix
|
||||
ax.plot(profile_value.detach().numpy())
|
||||
|
||||
# add title
|
||||
ax.set_title(f'{self.profile} profile')
|
||||
|
||||
# add to tensorboard
|
||||
if isinstance(trainer.logger, TensorBoardLogger):
|
||||
trainer.logger.experiment.add_figure(
|
||||
f"{self.profile}_matrix",
|
||||
fig,
|
||||
trainer.global_step,
|
||||
)
|
||||
else:
|
||||
class_name = self.__class__.__name__
|
||||
logger_name = trainer.logger.__class__.__name__
|
||||
warnings.warn(
|
||||
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
|
||||
)
|
||||
|
||||
|
||||
class OmegaTraceNormalization(pl.Callback):
|
||||
'''
|
||||
Trace normalization of the Omega Matrix.
|
||||
'''
|
||||
__epsilon = torch.finfo(torch.float32).eps
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer",
|
||||
pl_module: OmegaComparisonMixin) -> None:
|
||||
|
||||
omega = pl_module.parameter_omega
|
||||
denominator = torch.sqrt(torch.trace(omega.T @ omega))
|
||||
logging.debug(
|
||||
"Apply Omega Trace Normalization: demoninator=%f",
|
||||
denominator.item(),
|
||||
)
|
||||
pl_module.parameter_omega = omega / (denominator + self.__epsilon)
|
||||
|
@@ -1,77 +0,0 @@
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from ..core.competitions import CBCC
|
||||
from ..core.components import ReasoningComponents
|
||||
from ..core.initializers import RandomReasoningsInitializer
|
||||
from ..core.losses import MarginLoss
|
||||
from ..core.similarities import euclidean_similarity
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import ImagePrototypesMixin
|
||||
from .glvq import SiameseGLVQ
|
||||
|
||||
|
||||
class CBC(SiameseGLVQ):
|
||||
"""Classification-By-Components."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||
components_initializer = kwargs.get("components_initializer", None)
|
||||
reasonings_initializer = kwargs.get("reasonings_initializer",
|
||||
RandomReasoningsInitializer())
|
||||
self.components_layer = ReasoningComponents(
|
||||
self.hparams.distribution,
|
||||
components_initializer=components_initializer,
|
||||
reasonings_initializer=reasonings_initializer,
|
||||
)
|
||||
self.similarity_layer = LambdaLayer(similarity_fn)
|
||||
self.competition_layer = CBCC()
|
||||
|
||||
# Namespace hook
|
||||
self.proto_layer = self.components_layer
|
||||
|
||||
self.loss = MarginLoss(self.hparams.margin)
|
||||
|
||||
def forward(self, x):
|
||||
components, reasonings = self.components_layer()
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_components = self.backbone(components)
|
||||
self.backbone.requires_grad_(True)
|
||||
detections = self.similarity_layer(latent_x, latent_components)
|
||||
probs = self.competition_layer(detections, reasonings)
|
||||
return probs
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
y_pred = self(x)
|
||||
num_classes = self.num_classes
|
||||
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
|
||||
loss = self.loss(y_pred, y_true).mean()
|
||||
return y_pred, loss
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
preds = torch.argmax(y_pred, dim=1)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
||||
batch[1].int())
|
||||
self.log("train_acc",
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
return train_loss
|
||||
|
||||
def predict(self, x):
|
||||
with torch.no_grad():
|
||||
y_pred = self(x)
|
||||
y_pred = torch.argmax(y_pred, dim=1)
|
||||
return y_pred
|
||||
|
||||
|
||||
class ImageCBC(ImagePrototypesMixin, CBC):
|
||||
"""CBC model that constrains the components to the range [0, 1] by
|
||||
clamping after updates.
|
||||
"""
|
@@ -1,123 +0,0 @@
|
||||
"""Prototorch Data Modules
|
||||
|
||||
This allows to store the used dataset inside a Lightning Module.
|
||||
Mainly used for PytorchLightningCLI configurations.
|
||||
"""
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
|
||||
# MNIST
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
|
||||
# Download mnist dataset as side-effect, only called on the first cpu
|
||||
def prepare_data(self):
|
||||
MNIST("~/datasets", train=True, download=True)
|
||||
MNIST("~/datasets", train=False, download=True)
|
||||
|
||||
# called for every GPU/machine (assigning state is OK)
|
||||
def setup(self, stage=None):
|
||||
# Transforms
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
# Split dataset
|
||||
if stage in (None, "fit"):
|
||||
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_train,
|
||||
[55000, 5000],
|
||||
)
|
||||
if stage == (None, "test"):
|
||||
self.mnist_test = MNIST(
|
||||
"~/datasets",
|
||||
train=False,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
# Dataloaders
|
||||
def train_dataloader(self):
|
||||
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||
return mnist_train
|
||||
|
||||
def val_dataloader(self):
|
||||
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
|
||||
return mnist_val
|
||||
|
||||
def test_dataloader(self):
|
||||
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
|
||||
return mnist_test
|
||||
|
||||
|
||||
# def train_on_mnist(batch_size=256) -> type:
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = MNISTDataModule(batch_size=batch_size)
|
||||
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
# dc: Type[DataClass] = DataClass
|
||||
# return dc
|
||||
|
||||
|
||||
# ABSTRACT
|
||||
class GeneralDataModule(pl.LightningDataModule):
|
||||
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
|
||||
super().__init__()
|
||||
self.train_dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return DataLoader(self.train_dataset, batch_size=self.batch_size)
|
||||
|
||||
|
||||
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = GeneralDataModule(dataset, batch_size)
|
||||
# datashape = dataset[0][0].shape
|
||||
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
|
||||
|
||||
# def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer",
|
||||
# pt.components.Zeros(self.datashape),
|
||||
# )
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
# return DataClass
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# from prototorch.models import GLVQ
|
||||
|
||||
# demo_dataset = pt.datasets.Iris()
|
||||
|
||||
# TrainingClass: Type = train_on_dataset(demo_dataset)
|
||||
|
||||
# class DemoGLVQ(TrainingClass, GLVQ):
|
||||
# """Model Definition."""
|
||||
|
||||
# # Hyperparameters
|
||||
# hparams = dict(
|
||||
# distribution={
|
||||
# "num_classes": 3,
|
||||
# "prototypes_per_class": 4
|
||||
# },
|
||||
# lr=0.01,
|
||||
# )
|
||||
|
||||
# initialized = DemoGLVQ(hparams)
|
||||
# print(initialized)
|
@@ -1,90 +0,0 @@
|
||||
"""prototorch.models.extras
|
||||
|
||||
Modules not yet available in prototorch go here temporarily.
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from ..core.similarities import gaussian
|
||||
|
||||
|
||||
def rank_scaled_gaussian(distances, lambd):
|
||||
order = torch.argsort(distances, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||
|
||||
|
||||
class GaussianPrior(torch.nn.Module):
|
||||
def __init__(self, variance):
|
||||
super().__init__()
|
||||
self.variance = variance
|
||||
|
||||
def forward(self, distances):
|
||||
return gaussian(distances, self.variance)
|
||||
|
||||
|
||||
class RankScaledGaussianPrior(torch.nn.Module):
|
||||
def __init__(self, lambd):
|
||||
super().__init__()
|
||||
self.lambd = lambd
|
||||
|
||||
def forward(self, distances):
|
||||
return rank_scaled_gaussian(distances, self.lambd)
|
||||
|
||||
|
||||
class ConnectionTopology(torch.nn.Module):
|
||||
def __init__(self, agelimit, num_prototypes):
|
||||
super().__init__()
|
||||
self.agelimit = agelimit
|
||||
self.num_prototypes = num_prototypes
|
||||
|
||||
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
|
||||
self.age = torch.zeros_like(self.cmat)
|
||||
|
||||
def forward(self, d):
|
||||
order = torch.argsort(d, dim=1)
|
||||
|
||||
for element in order:
|
||||
i0, i1 = element[0], element[1]
|
||||
|
||||
self.cmat[i0][i1] = 1
|
||||
self.cmat[i1][i0] = 1
|
||||
|
||||
self.age[i0][i1] = 0
|
||||
self.age[i1][i0] = 0
|
||||
|
||||
self.age[i0][self.cmat[i0] == 1] += 1
|
||||
self.age[i1][self.cmat[i1] == 1] += 1
|
||||
|
||||
self.cmat[i0][self.age[i0] > self.agelimit] = 0
|
||||
self.cmat[i1][self.age[i1] > self.agelimit] = 0
|
||||
|
||||
def get_neighbors(self, position):
|
||||
return torch.where(self.cmat[position])
|
||||
|
||||
def add_prototype(self):
|
||||
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
|
||||
new_cmat[:-1, :-1] = self.cmat
|
||||
self.cmat = new_cmat
|
||||
|
||||
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
|
||||
new_age[:-1, :-1] = self.age
|
||||
self.age = new_age
|
||||
|
||||
def add_connection(self, a, b):
|
||||
self.cmat[a][b] = 1
|
||||
self.cmat[b][a] = 1
|
||||
|
||||
self.age[a][b] = 0
|
||||
self.age[b][a] = 0
|
||||
|
||||
def remove_connection(self, a, b):
|
||||
self.cmat[a][b] = 0
|
||||
self.cmat[b][a] = 0
|
||||
|
||||
self.age[a][b] = 0
|
||||
self.age[b][a] = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(agelimit): ({self.agelimit})"
|
@@ -1,310 +0,0 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||
from ..core.initializers import EyeTransformInitializer
|
||||
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||
from ..core.transforms import LinearTransform
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
|
||||
|
||||
class GLVQ(SupervisedPrototypeModel):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("margin", 0.0)
|
||||
self.hparams.setdefault("transfer_fn", "identity")
|
||||
self.hparams.setdefault("transfer_beta", 10.0)
|
||||
|
||||
# Loss
|
||||
self.loss = GLVQLoss(
|
||||
margin=self.hparams.margin,
|
||||
transfer_fn=self.hparams.transfer_fn,
|
||||
beta=self.hparams.transfer_beta,
|
||||
)
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
torch.zeros(self.num_prototypes, device=self.device))
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
def log_prototype_win_ratios(self, distances):
|
||||
batch_size = len(distances)
|
||||
prototype_wc = torch.zeros(self.num_prototypes,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
wi, wc = torch.unique(distances.min(dim=-1).indices,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
prototype_wc[wi] = wc
|
||||
prototype_wr = prototype_wc / batch_size
|
||||
self.prototype_win_ratios = torch.vstack([
|
||||
self.prototype_win_ratios,
|
||||
prototype_wr,
|
||||
])
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x)
|
||||
_, plabels = self.proto_layer()
|
||||
loss = self.loss(out, y, plabels)
|
||||
return out, loss
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
self.log_prototype_win_ratios(out)
|
||||
self.log("train_loss", train_loss)
|
||||
self.log_acc(out, batch[-1], tag="train_acc")
|
||||
return train_loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# `model.eval()` and `torch.no_grad()` handled by pl
|
||||
out, val_loss = self.shared_step(batch, batch_idx)
|
||||
self.log("val_loss", val_loss)
|
||||
self.log_acc(out, batch[-1], tag="val_acc")
|
||||
return val_loss
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
# `model.eval()` and `torch.no_grad()` handled by pl
|
||||
out, test_loss = self.shared_step(batch, batch_idx)
|
||||
self.log_acc(out, batch[-1], tag="test_acc")
|
||||
return test_loss
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
test_loss = 0.0
|
||||
for batch_loss in outputs:
|
||||
test_loss += batch_loss.item()
|
||||
self.log("test_loss", test_loss)
|
||||
|
||||
# TODO
|
||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
# pass
|
||||
|
||||
|
||||
class SiameseGLVQ(GLVQ):
|
||||
"""GLVQ in a Siamese setting.
|
||||
|
||||
GLVQ model that applies an arbitrary transformation on the inputs and the
|
||||
prototypes before computing the distances between them. The weights in the
|
||||
transformation pipeline are only learned from the inputs.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
hparams,
|
||||
backbone=torch.nn.Identity(),
|
||||
both_path_gradients=False,
|
||||
**kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
self.backbone = backbone
|
||||
self.both_path_gradients = both_path_gradients
|
||||
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
# Only add a backbone optimizer if backbone has trainable parameters
|
||||
bb_params = list(self.backbone.parameters())
|
||||
if (bb_params):
|
||||
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
|
||||
optimizers = [proto_opt, bb_opt]
|
||||
else:
|
||||
optimizers = [proto_opt]
|
||||
if self.lr_scheduler is not None:
|
||||
schedulers = []
|
||||
for optimizer in optimizers:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
schedulers.append(scheduler)
|
||||
return optimizers, schedulers
|
||||
else:
|
||||
return optimizers
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
self.backbone.requires_grad_(True)
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
return distances
|
||||
|
||||
def predict_latent(self, x, map_protos=True):
|
||||
"""Predict `x` assuming it is already embedded in the latent space.
|
||||
|
||||
Only the prototypes are embedded in the latent space using the
|
||||
backbone.
|
||||
|
||||
"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
protos, plabels = self.proto_layer()
|
||||
if map_protos:
|
||||
protos = self.backbone(protos)
|
||||
d = self.distance_layer(x, protos)
|
||||
y_pred = wtac(d, plabels)
|
||||
return y_pred
|
||||
|
||||
|
||||
class LVQMLN(SiameseGLVQ):
|
||||
"""Learning Vector Quantization Multi-Layer Network.
|
||||
|
||||
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
|
||||
on the prototypes before computing the distances between them. This of
|
||||
course, means that the prototypes no longer live the input space, but
|
||||
rather in the embedding space.
|
||||
|
||||
"""
|
||||
def compute_distances(self, x):
|
||||
latent_protos, _ = self.proto_layer()
|
||||
latent_x = self.backbone(x)
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
return distances
|
||||
|
||||
|
||||
class GRLVQ(SiameseGLVQ):
|
||||
"""Generalized Relevance Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
relevances = torch.ones(self.hparams.input_dim, device=self.device)
|
||||
self.register_parameter("_relevances", Parameter(relevances))
|
||||
|
||||
# Override the backbone
|
||||
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
|
||||
name="relevance scaling")
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
return self._relevances.detach().cpu()
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(relevances): (shape: {tuple(self._relevances.shape)})"
|
||||
|
||||
|
||||
class SiameseGMLVQ(SiameseGLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Override the backbone
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeTransformInitializer())
|
||||
self.backbone = LinearTransform(
|
||||
self.hparams.input_dim,
|
||||
self.hparams.output_dim,
|
||||
initializer=omega_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self.backbone.weights
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
omega = self.backbone.weight # (input_dim, latent_dim)
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
|
||||
class GMLVQ(GLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a regular GLVQ network that simply uses a different distance
|
||||
function. This makes it easier to implement a localized variant.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeTransformInitializer())
|
||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
self.backbone = LambdaLayer(lambda x: x @ self._omega,
|
||||
name="omega matrix")
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self._omega.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, self._omega)
|
||||
return distances
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
||||
|
||||
|
||||
class LGMLVQ(GMLVQ):
|
||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Re-register `_omega` to override the one from the super class.
|
||||
omega = torch.randn(
|
||||
self.num_prototypes,
|
||||
self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
|
||||
class GLVQ1(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = LossLayer(lvq1_loss)
|
||||
self.optimizer = torch.optim.SGD
|
||||
|
||||
|
||||
class GLVQ21(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 2.1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = LossLayer(lvq21_loss)
|
||||
self.optimizer = torch.optim.SGD
|
||||
|
||||
|
||||
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
|
||||
"""GLVQ for training on image data.
|
||||
|
||||
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
after updates.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
||||
"""GMLVQ for training on image data.
|
||||
|
||||
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
after updates.
|
||||
|
||||
"""
|
@@ -1,43 +0,0 @@
|
||||
"""ProtoTorch KNN model."""
|
||||
|
||||
import warnings
|
||||
|
||||
from ..core.competitions import KNNC
|
||||
from ..core.components import LabeledComponents
|
||||
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||
from ..utils.utils import parse_data_arg
|
||||
from .abstract import SupervisedPrototypeModel
|
||||
|
||||
|
||||
class KNN(SupervisedPrototypeModel):
|
||||
"""K-Nearest-Neighbors classification algorithm."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("k", 1)
|
||||
|
||||
data = kwargs.get("data", None)
|
||||
if data is None:
|
||||
raise ValueError("KNN requires data, but was not provided!")
|
||||
data, targets = parse_data_arg(data)
|
||||
|
||||
# Layers
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=[],
|
||||
components_initializer=LiteralCompInitializer(data),
|
||||
labels_initializer=LiteralLabelsInitializer(targets))
|
||||
self.competition_layer = KNNC(k=self.hparams.k)
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
return 1 # skip training step
|
||||
|
||||
def on_train_batch_start(self,
|
||||
train_batch,
|
||||
batch_idx,
|
||||
dataloader_idx=None):
|
||||
warnings.warn("k-NN has no training, skipping!")
|
||||
return -1
|
||||
|
||||
def configure_optimizers(self):
|
||||
return None
|
7
prototorch/models/library/__init__.py
Normal file
7
prototorch/models/library/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .glvq import GLVQ
|
||||
from .gmlvq import GMLVQ
|
||||
|
||||
__all__ = [
|
||||
"GLVQ",
|
||||
"GMLVQ",
|
||||
]
|
35
prototorch/models/library/glvq.py
Normal file
35
prototorch/models/library/glvq.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from prototorch.models import (
|
||||
SimpleComparisonMixin,
|
||||
SingleLearningRateMixin,
|
||||
SupervisedArchitecture,
|
||||
WTACompetitionMixin,
|
||||
)
|
||||
from prototorch.models.architectures.loss import GLVQLossMixin
|
||||
|
||||
|
||||
class GLVQ(
|
||||
SupervisedArchitecture,
|
||||
SimpleComparisonMixin,
|
||||
GLVQLossMixin,
|
||||
WTACompetitionMixin,
|
||||
SingleLearningRateMixin,
|
||||
):
|
||||
"""
|
||||
Generalized Learning Vector Quantization (GLVQ)
|
||||
|
||||
A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class HyperParameters(
|
||||
SimpleComparisonMixin.HyperParameters,
|
||||
SingleLearningRateMixin.HyperParameters,
|
||||
GLVQLossMixin.HyperParameters,
|
||||
WTACompetitionMixin.HyperParameters,
|
||||
SupervisedArchitecture.HyperParameters,
|
||||
):
|
||||
"""
|
||||
No hyperparameters.
|
||||
"""
|
50
prototorch/models/library/gmlvq.py
Normal file
50
prototorch/models/library/gmlvq.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from prototorch.core.distances import omega_distance
|
||||
from prototorch.models import (
|
||||
GLVQLossMixin,
|
||||
MultipleLearningRateMixin,
|
||||
OmegaComparisonMixin,
|
||||
SupervisedArchitecture,
|
||||
WTACompetitionMixin,
|
||||
)
|
||||
|
||||
|
||||
class GMLVQ(
|
||||
SupervisedArchitecture,
|
||||
OmegaComparisonMixin,
|
||||
GLVQLossMixin,
|
||||
WTACompetitionMixin,
|
||||
MultipleLearningRateMixin,
|
||||
):
|
||||
"""
|
||||
Generalized Matrix Learning Vector Quantization (GMLVQ)
|
||||
|
||||
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
|
||||
"""
|
||||
# HyperParameters
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class HyperParameters(
|
||||
MultipleLearningRateMixin.HyperParameters,
|
||||
OmegaComparisonMixin.HyperParameters,
|
||||
GLVQLossMixin.HyperParameters,
|
||||
WTACompetitionMixin.HyperParameters,
|
||||
SupervisedArchitecture.HyperParameters,
|
||||
):
|
||||
"""
|
||||
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
|
||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = omega_distance
|
||||
comparison_args: dict = field(default_factory=dict)
|
||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
lr: dict = field(default_factory=lambda: dict(
|
||||
components_layer=0.1,
|
||||
_omega=0.5,
|
||||
))
|
@@ -1,124 +0,0 @@
|
||||
"""LVQ models that are optimized using non-gradient methods."""
|
||||
|
||||
from ..core.losses import _get_dp_dm
|
||||
from ..nn.activations import get_activation
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import NonGradientMixin
|
||||
from .glvq import GLVQ
|
||||
|
||||
|
||||
class LVQ1(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 1."""
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos, plables = self.proto_layer()
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
# TODO Vectorized implementation
|
||||
|
||||
for xi, yi in zip(x, y):
|
||||
d = self.compute_distances(xi.view(1, -1))
|
||||
preds = self.competition_layer(d, plabels)
|
||||
w = d.argmin(1)
|
||||
if yi == preds:
|
||||
shift = xi - protos[w]
|
||||
else:
|
||||
shift = protos[w] - xi
|
||||
updated_protos = protos + 0.0
|
||||
updated_protos[w] = protos[w] + (self.hparams.lr * shift)
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
|
||||
print(f"dis={dis}")
|
||||
print(f"y={y}")
|
||||
# Logging
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class LVQ21(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 2.1."""
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos, plabels = self.proto_layer()
|
||||
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
# TODO Vectorized implementation
|
||||
|
||||
for xi, yi in zip(x, y):
|
||||
xi = xi.view(1, -1)
|
||||
yi = yi.view(1, )
|
||||
d = self.compute_distances(xi)
|
||||
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
|
||||
shiftp = xi - protos[wp]
|
||||
shiftn = protos[wn] - xi
|
||||
updated_protos = protos + 0.0
|
||||
updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
|
||||
updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
|
||||
# Logging
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MedianLVQ(NonGradientMixin, GLVQ):
|
||||
"""Median LVQ
|
||||
|
||||
# TODO Avoid computing distances over and over
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, verbose=True, **kwargs):
|
||||
self.verbose = verbose
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.transfer_layer = LambdaLayer(
|
||||
get_activation(self.hparams.transfer_fn))
|
||||
|
||||
def _f(self, x, y, protos, plabels):
|
||||
d = self.distance_layer(x, protos)
|
||||
dp, dm = _get_dp_dm(d, y, plabels)
|
||||
mu = (dp - dm) / (dp + dm)
|
||||
invmu = -1.0 * mu
|
||||
f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
|
||||
return f
|
||||
|
||||
def expectation(self, x, y, protos, plabels):
|
||||
f = self._f(x, y, protos, plabels)
|
||||
gamma = f / f.sum()
|
||||
return gamma
|
||||
|
||||
def lower_bound(self, x, y, protos, plabels, gamma):
|
||||
f = self._f(x, y, protos, plabels)
|
||||
lower_bound = (gamma * f.log()).sum()
|
||||
return lower_bound
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos, plabels = self.proto_layer()
|
||||
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
|
||||
for i, _ in enumerate(protos):
|
||||
# Expectation step
|
||||
gamma = self.expectation(x, y, protos, plabels)
|
||||
lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
|
||||
|
||||
# Maximization step
|
||||
_protos = protos + 0
|
||||
for k, xk in enumerate(x):
|
||||
_protos[i] = xk
|
||||
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
|
||||
if _lower_bound > lower_bound:
|
||||
if self.verbose:
|
||||
print(f"Updating prototype {i} to data {k}...")
|
||||
self.proto_layer.load_state_dict({"_components": _protos},
|
||||
strict=False)
|
||||
break
|
||||
|
||||
# Logging
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return None
|
@@ -1,96 +0,0 @@
|
||||
"""Probabilistic GLVQ methods"""
|
||||
|
||||
import torch
|
||||
|
||||
from ..core.losses import nllr_loss, rslvq_loss
|
||||
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||
from .glvq import GLVQ, SiameseGMLVQ
|
||||
|
||||
|
||||
class CELVQ(GLVQ):
|
||||
"""Cross-Entropy Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Loss
|
||||
self.loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x) # [None, num_protos]
|
||||
_, plabels = self.proto_layer()
|
||||
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
||||
probs = -1.0 * winning
|
||||
batch_loss = self.loss(probs, y.long())
|
||||
loss = batch_loss.sum()
|
||||
return out, loss
|
||||
|
||||
|
||||
class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.conditional_distribution = None
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def forward(self, x):
|
||||
distances = self.compute_distances(x)
|
||||
conditional = self.conditional_distribution(distances)
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||
device=self.device)
|
||||
posterior = conditional * prior
|
||||
plabels = self.proto_layer._labels
|
||||
y_pred = stratified_sum_pooling(posterior, plabels)
|
||||
return y_pred
|
||||
|
||||
def predict(self, x):
|
||||
y_pred = self.forward(x)
|
||||
confidence, prediction = torch.max(y_pred, dim=1)
|
||||
prediction[confidence < self.rejection_confidence] = -1
|
||||
return prediction
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.forward(x)
|
||||
_, plabels = self.proto_layer()
|
||||
batch_loss = self.loss(out, y, plabels)
|
||||
loss = batch_loss.sum()
|
||||
return loss
|
||||
|
||||
|
||||
class SLVQ(ProbabilisticLVQ):
|
||||
"""Soft Learning Vector Quantization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
"""Robust Soft Learning Vector Quantization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
"""Probabilistic Learning Vector Quantization.
|
||||
|
||||
TODO: Use Backbone LVQ instead
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conditional_distribution = RankScaledGaussianPrior(
|
||||
self.hparams.lambd)
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
# FIXME
|
||||
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
# x, y = batch
|
||||
# y_pred = self(x)
|
||||
# batch_loss = self.loss(y_pred, y)
|
||||
# loss = batch_loss.sum()
|
||||
# return loss
|
@@ -1,146 +0,0 @@
|
||||
"""Unsupervised prototype learning algorithms."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import squared_euclidean_distance
|
||||
from ..core.losses import NeuralGasEnergy
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||
from .callbacks import GNGCallback
|
||||
from .extras import ConnectionTopology
|
||||
|
||||
|
||||
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
"""Kohonen Self-Organizing-Map.
|
||||
|
||||
TODO Allow non-2D grids
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
h, w = hparams.get("shape")
|
||||
# Ignore `num_prototypes`
|
||||
hparams["num_prototypes"] = h * w
|
||||
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Hyperparameters
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("alpha", 0.3)
|
||||
self.hparams.setdefault("sigma", max(h, w) / 2.0)
|
||||
|
||||
# Additional parameters
|
||||
x, y = torch.arange(h), torch.arange(w)
|
||||
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
|
||||
self.register_buffer("_grid", grid)
|
||||
self._sigma = self.hparams.sigma
|
||||
self._lr = self.hparams.lr
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
grid = self._grid.view(-1, 2)
|
||||
wp = wtac(distances, grid)
|
||||
return wp
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
# x = train_batch
|
||||
# TODO Check if the batch has labels
|
||||
x = train_batch[0]
|
||||
d = self.compute_distances(x)
|
||||
wp = self.predict_from_distances(d)
|
||||
grid = self._grid.view(-1, 2)
|
||||
gd = squared_euclidean_distance(wp, grid)
|
||||
nh = torch.exp(-gd / self._sigma**2)
|
||||
protos = self.proto_layer()
|
||||
diff = x.unsqueeze(dim=1) - protos
|
||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||
updated_protos = protos + delta.sum(dim=0)
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
self._sigma = self.hparams.sigma * np.exp(
|
||||
-self.current_epoch / self.trainer.max_epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(grid): (shape: {tuple(self._grid.shape)})"
|
||||
|
||||
|
||||
class HeskesSOM(UnsupervisedPrototypeModel):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
# TODO Implement me!
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NeuralGas(UnsupervisedPrototypeModel):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Hyperparameters
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("agelimit", 10)
|
||||
self.hparams.setdefault("lm", 1)
|
||||
|
||||
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
||||
self.topology_layer = ConnectionTopology(
|
||||
agelimit=self.hparams.agelimit,
|
||||
num_prototypes=self.hparams.num_prototypes,
|
||||
)
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
# x = train_batch
|
||||
# TODO Check if the batch has labels
|
||||
x = train_batch[0]
|
||||
d = self.compute_distances(x)
|
||||
loss, _ = self.energy_layer(d)
|
||||
self.topology_layer(d)
|
||||
self.log("loss", loss)
|
||||
return loss
|
||||
|
||||
# def training_epoch_end(self, training_step_outputs):
|
||||
# print(f"{self.trainer.lr_schedulers}")
|
||||
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
|
||||
|
||||
|
||||
class GrowingNeuralGas(NeuralGas):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Defaults
|
||||
self.hparams.setdefault("step_reduction", 0.5)
|
||||
self.hparams.setdefault("insert_reduction", 0.1)
|
||||
self.hparams.setdefault("insert_freq", 10)
|
||||
|
||||
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
|
||||
self.register_buffer("errors", errors)
|
||||
|
||||
def training_step(self, train_batch, _batch_idx):
|
||||
# x = train_batch
|
||||
# TODO Check if the batch has labels
|
||||
x = train_batch[0]
|
||||
d = self.compute_distances(x)
|
||||
loss, order = self.energy_layer(d)
|
||||
winner = order[:, 0]
|
||||
mask = torch.zeros_like(d)
|
||||
mask[torch.arange(len(mask)), winner] = 1.0
|
||||
dp = d * mask
|
||||
|
||||
self.errors += torch.sum(dp * dp)
|
||||
self.errors *= self.hparams.step_reduction
|
||||
|
||||
self.topology_layer(d)
|
||||
self.log("loss", loss)
|
||||
return loss
|
||||
|
||||
def configure_callbacks(self):
|
||||
return [
|
||||
GNGCallback(reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq)
|
||||
]
|
@@ -1,20 +1,28 @@
|
||||
"""Visualization Callbacks."""
|
||||
|
||||
import warnings
|
||||
from typing import Sized
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.utils.colors import get_colors, get_legend_handles
|
||||
from prototorch.utils.utils import mesh2d
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ..utils.utils import mesh2d
|
||||
|
||||
|
||||
class Vis2DAbstract(pl.Callback):
|
||||
|
||||
def __init__(self,
|
||||
data,
|
||||
data=None,
|
||||
title="Prototype Visualization",
|
||||
cmap="viridis",
|
||||
xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2",
|
||||
legend_labels=None,
|
||||
border=0.1,
|
||||
resolution=100,
|
||||
flatten_data=True,
|
||||
@@ -27,24 +35,36 @@ class Vis2DAbstract(pl.Callback):
|
||||
block=False):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(data, Dataset):
|
||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||
elif isinstance(data, torch.utils.data.DataLoader):
|
||||
x = torch.tensor([])
|
||||
y = torch.tensor([])
|
||||
for x_b, y_b in data:
|
||||
x = torch.cat([x, x_b])
|
||||
y = torch.cat([y, y_b])
|
||||
if data:
|
||||
if isinstance(data, Dataset):
|
||||
if isinstance(data, Sized):
|
||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||
else:
|
||||
# TODO: Add support for non-sized datasets
|
||||
raise NotImplementedError(
|
||||
"Data must be a dataset with a __len__ method.")
|
||||
elif isinstance(data, DataLoader):
|
||||
x = torch.tensor([])
|
||||
y = torch.tensor([])
|
||||
for x_b, y_b in data:
|
||||
x = torch.cat([x, x_b])
|
||||
y = torch.cat([y, y_b])
|
||||
else:
|
||||
x, y = data
|
||||
|
||||
if flatten_data:
|
||||
x = x.reshape(len(x), -1)
|
||||
|
||||
self.x_train = x
|
||||
self.y_train = y
|
||||
else:
|
||||
x, y = data
|
||||
|
||||
if flatten_data:
|
||||
x = x.reshape(len(x), -1)
|
||||
|
||||
self.x_train = x
|
||||
self.y_train = y
|
||||
self.x_train = None
|
||||
self.y_train = None
|
||||
|
||||
self.title = title
|
||||
self.xlabel = xlabel
|
||||
self.ylabel = ylabel
|
||||
self.legend_labels = legend_labels
|
||||
self.fig = plt.figure(self.title)
|
||||
self.cmap = cmap
|
||||
self.border = border
|
||||
@@ -63,14 +83,12 @@ class Vis2DAbstract(pl.Callback):
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup_ax(self, xlabel=None, ylabel=None):
|
||||
def setup_ax(self):
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
if xlabel:
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
if ylabel:
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
ax.set_xlabel(self.xlabel)
|
||||
ax.set_ylabel(self.ylabel)
|
||||
if self.axis_off:
|
||||
ax.axis("off")
|
||||
return ax
|
||||
@@ -113,42 +131,47 @@ class Vis2DAbstract(pl.Callback):
|
||||
else:
|
||||
plt.show(block=self.block)
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
self.visualize(pl_module)
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
plt.close()
|
||||
|
||||
def visualize(self, pl_module):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
ax = self.setup_ax()
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||
if x_train is not None:
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
||||
self.border, self.resolution)
|
||||
else:
|
||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||
_components = pl_module.proto_layer._components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
|
||||
class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
|
||||
def __init__(self, *args, map_protos=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.map_protos = map_protos
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
@@ -175,18 +198,42 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
class VisGMLVQ2D(Vis2DAbstract):
|
||||
|
||||
def __init__(self, *args, ev_proj=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ev_proj = ev_proj
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
device = pl_module.device
|
||||
omega = pl_module._omega.detach()
|
||||
lam = omega @ omega.T
|
||||
u, _, _ = torch.pca_lowrank(lam, q=2)
|
||||
with torch.no_grad():
|
||||
x_train = torch.Tensor(x_train).to(device)
|
||||
x_train = x_train @ u
|
||||
x_train = x_train.cpu().detach()
|
||||
if self.show_protos:
|
||||
with torch.no_grad():
|
||||
protos = torch.Tensor(protos).to(device)
|
||||
protos = protos @ u
|
||||
protos = protos.cpu().detach()
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
|
||||
|
||||
class VisCBC2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.components
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
@@ -198,20 +245,15 @@ class VisCBC2D(Vis2DAbstract):
|
||||
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
|
||||
class VisNG2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.prototypes
|
||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
|
||||
@@ -225,10 +267,27 @@ class VisNG2D(Vis2DAbstract):
|
||||
"k-",
|
||||
)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
class VisSpectralProtos(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
ax = self.setup_ax()
|
||||
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
|
||||
for p, pl in zip(protos, plabels):
|
||||
ax.plot(p, c=colors[int(pl)])
|
||||
if self.legend_labels:
|
||||
handles = get_legend_handles(
|
||||
colors,
|
||||
self.legend_labels,
|
||||
marker="lines",
|
||||
)
|
||||
ax.legend(handles=handles)
|
||||
|
||||
|
||||
class VisImgComp(Vis2DAbstract):
|
||||
|
||||
def __init__(self,
|
||||
*args,
|
||||
random_data=0,
|
||||
@@ -244,30 +303,45 @@ class VisImgComp(Vis2DAbstract):
|
||||
self.add_embedding = add_embedding
|
||||
self.embedding_data = embedding_data
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
tb = pl_module.logger.experiment
|
||||
if self.add_embedding:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.embedding_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
tb.add_embedding(data.view(len(ind), -1),
|
||||
label_img=data,
|
||||
global_step=None,
|
||||
tag="Data Embedding",
|
||||
metadata=self.y_train[ind],
|
||||
metadata_header=None)
|
||||
def on_train_start(self, _, pl_module):
|
||||
if isinstance(pl_module.logger, TensorBoardLogger):
|
||||
tb = pl_module.logger.experiment
|
||||
|
||||
if self.random_data:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.random_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
|
||||
tb.add_image(tag="Data",
|
||||
img_tensor=grid,
|
||||
global_step=None,
|
||||
dataformats=self.dataformats)
|
||||
# Add embedding
|
||||
if self.add_embedding:
|
||||
if self.x_train is not None and self.y_train is not None:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.embedding_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
tb.add_embedding(data.view(len(ind), -1),
|
||||
label_img=data,
|
||||
global_step=None,
|
||||
tag="Data Embedding",
|
||||
metadata=self.y_train[ind],
|
||||
metadata_header=None)
|
||||
else:
|
||||
raise ValueError("No data for add embedding flag")
|
||||
|
||||
# Random Data
|
||||
if self.random_data:
|
||||
if self.x_train is not None:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.random_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
grid = torchvision.utils.make_grid(data,
|
||||
nrow=self.num_columns)
|
||||
tb.add_image(tag="Data",
|
||||
img_tensor=grid,
|
||||
global_step=None,
|
||||
dataformats=self.dataformats)
|
||||
else:
|
||||
raise ValueError("No data for random data flag")
|
||||
|
||||
else:
|
||||
warnings.warn(
|
||||
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
|
||||
|
||||
def add_to_tensorboard(self, trainer, pl_module):
|
||||
tb = pl_module.logger.experiment
|
||||
@@ -281,14 +355,9 @@ class VisImgComp(Vis2DAbstract):
|
||||
dataformats=self.dataformats,
|
||||
)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
if self.show:
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components,
|
||||
nrow=self.num_columns)
|
||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
23
setup.cfg
23
setup.cfg
@@ -1,8 +1,23 @@
|
||||
[isort]
|
||||
profile = hug
|
||||
src_paths = isort, test
|
||||
|
||||
[yapf]
|
||||
based_on_style = pep8
|
||||
spaces_before_comment = 2
|
||||
split_before_logical_operator = true
|
||||
|
||||
[pylint]
|
||||
disable =
|
||||
too-many-arguments,
|
||||
too-few-public-methods,
|
||||
fixme,
|
||||
|
||||
|
||||
[pycodestyle]
|
||||
max-line-length = 79
|
||||
|
||||
[isort]
|
||||
profile = hug
|
||||
src_paths = isort, test
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = True
|
||||
force_grid_wrap = 3
|
||||
use_parentheses = True
|
||||
line_length = 79
|
||||
|
18
setup.py
18
setup.py
@@ -10,6 +10,8 @@
|
||||
|
||||
ProtoTorch models Plugin Package
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from pkg_resources import safe_name
|
||||
from setuptools import find_namespace_packages, setup
|
||||
|
||||
@@ -18,13 +20,13 @@ PLUGIN_NAME = "models"
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
long_description = Path("README.md").read_text(encoding='utf8')
|
||||
|
||||
INSTALL_REQUIRES = [
|
||||
"prototorch>=0.7.0",
|
||||
"pytorch_lightning>=1.3.5",
|
||||
"prototorch>=0.7.3",
|
||||
"pytorch_lightning>=1.6.0",
|
||||
"torchmetrics",
|
||||
"protobuf<3.20.0",
|
||||
]
|
||||
CLI = [
|
||||
"jsonargparse",
|
||||
@@ -37,6 +39,7 @@ DOCS = [
|
||||
"recommonmark",
|
||||
"sphinx",
|
||||
"nbsphinx",
|
||||
"ipykernel",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib-katex",
|
||||
"sphinxcontrib-bibtex",
|
||||
@@ -53,7 +56,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
|
||||
|
||||
setup(
|
||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||
version="0.3.0",
|
||||
version="1.0.0-a8",
|
||||
description="Pre-packaged prototype-based "
|
||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||
long_description=long_description,
|
||||
@@ -63,7 +66,7 @@ setup(
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license="MIT",
|
||||
python_requires=">=3.6",
|
||||
python_requires=">=3.7",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"dev": DEV,
|
||||
@@ -79,10 +82,11 @@ setup(
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
|
@@ -1,14 +0,0 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDummy(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_dummy(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
13
tests/test_models.py
Normal file
13
tests/test_models.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import prototorch as pt
|
||||
from prototorch.models.library import GLVQ
|
||||
|
||||
|
||||
def test_glvq_model_build():
|
||||
hparams = GLVQ.HyperParameters(
|
||||
distribution=dict(num_classes=2, per_class=1),
|
||||
component_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
model = GLVQ(hparams=hparams)
|
Reference in New Issue
Block a user