Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
aeb6417c28 | ||
|
cb7fb91c95 | ||
|
823b05e390 |
@ -1,13 +1,13 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.7.1
|
current_version = 0.2.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize = {major}.{minor}.{patch}
|
serialize = {major}.{minor}.{patch}
|
||||||
message = build: bump version {current_version} → {new_version}
|
message = build: bump version {current_version} → {new_version}
|
||||||
|
|
||||||
[bumpversion:file:pyproject.toml]
|
[bumpversion:file:setup.py]
|
||||||
|
|
||||||
[bumpversion:file:./src/prototorch/models/__init__.py]
|
[bumpversion:file:./prototorch/models/__init__.py]
|
||||||
|
|
||||||
[bumpversion:file:./docs/source/conf.py]
|
[bumpversion:file:./docs/source/conf.py]
|
||||||
|
15
.codacy.yml
Normal file
15
.codacy.yml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# To validate the contents of your configuration file
|
||||||
|
# run the following command in the folder where the configuration file is located:
|
||||||
|
# codacy-analysis-cli validate-configuration --directory `pwd`
|
||||||
|
# To analyse, run:
|
||||||
|
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
||||||
|
---
|
||||||
|
engines:
|
||||||
|
pylintpython3:
|
||||||
|
exclude_paths:
|
||||||
|
- config/engines.yml
|
||||||
|
remark-lint:
|
||||||
|
exclude_paths:
|
||||||
|
- config/engines.yml
|
||||||
|
exclude_paths:
|
||||||
|
- 'tests/**'
|
2
.codecov.yml
Normal file
2
.codecov.yml
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
comment:
|
||||||
|
require_changes: yes
|
25
.github/workflows/examples.yml
vendored
25
.github/workflows/examples.yml
vendored
@ -1,25 +0,0 @@
|
|||||||
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
|
|
||||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
|
||||||
|
|
||||||
name: examples
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
paths:
|
|
||||||
- "examples/**.py"
|
|
||||||
jobs:
|
|
||||||
cpu:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
- name: Set up Python 3.11
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install .[all]
|
|
||||||
- name: Run examples
|
|
||||||
run: |
|
|
||||||
./tests/test_examples.sh examples/
|
|
75
.github/workflows/pythonapp.yml
vendored
75
.github/workflows/pythonapp.yml
vendored
@ -1,75 +0,0 @@
|
|||||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
|
||||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
|
||||||
|
|
||||||
name: tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
pull_request:
|
|
||||||
branches: [master]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
style:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
- name: Set up Python 3.11
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install .[all]
|
|
||||||
- uses: pre-commit/action@v3.0.0
|
|
||||||
compatibility:
|
|
||||||
needs: style
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
|
||||||
os: [ubuntu-latest, windows-latest]
|
|
||||||
exclude:
|
|
||||||
- os: windows-latest
|
|
||||||
python-version: "3.8"
|
|
||||||
- os: windows-latest
|
|
||||||
python-version: "3.9"
|
|
||||||
- os: windows-latest
|
|
||||||
python-version: "3.10"
|
|
||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python-version }}
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install .[all]
|
|
||||||
- name: Test with pytest
|
|
||||||
run: |
|
|
||||||
pytest
|
|
||||||
publish_pypi:
|
|
||||||
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
|
|
||||||
needs: compatibility
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
- name: Set up Python 3.11
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install .[all]
|
|
||||||
pip install build
|
|
||||||
- name: Build package
|
|
||||||
run: python -m build . -C verbose
|
|
||||||
- name: Publish a Python distribution to PyPI
|
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
|
||||||
with:
|
|
||||||
user: __token__
|
|
||||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
|
@ -2,53 +2,52 @@
|
|||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v4.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
- id: check-ast
|
- id: check-ast
|
||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v2.1.1
|
rev: v1.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.12.0
|
rev: 5.8.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.3.0
|
rev: v0.902
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
additional_dependencies: [types-pkg_resources]
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
rev: v0.32.0
|
rev: v0.31.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
additional_dependencies: ["toml"]
|
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.10.0
|
rev: v1.9.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.7.0
|
rev: v2.19.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
- repo: https://github.com/si-cim/gitlint
|
- repo: https://github.com/si-cim/gitlint
|
||||||
rev: v0.15.2-unofficial
|
rev: v0.15.2-unofficial
|
||||||
hooks:
|
hooks:
|
||||||
- id: gitlint
|
- id: gitlint
|
||||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||||
|
25
.travis.yml
Normal file
25
.travis.yml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
dist: bionic
|
||||||
|
sudo: false
|
||||||
|
language: python
|
||||||
|
python: 3.9
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- "$HOME/.cache/pip"
|
||||||
|
- "./tests/artifacts"
|
||||||
|
- "$HOME/datasets"
|
||||||
|
install:
|
||||||
|
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
|
||||||
|
- pip install .[all] --progress-bar off
|
||||||
|
script:
|
||||||
|
- coverage run -m pytest
|
||||||
|
- ./tests/test_examples.sh examples/
|
||||||
|
after_success:
|
||||||
|
- bash <(curl -s https://codecov.io/bash)
|
||||||
|
deploy:
|
||||||
|
provider: pypi
|
||||||
|
username: __token__
|
||||||
|
password:
|
||||||
|
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
|
||||||
|
on:
|
||||||
|
tags: true
|
||||||
|
skip_existing: true
|
@ -1,5 +1,6 @@
|
|||||||
# ProtoTorch Models
|
# ProtoTorch Models
|
||||||
|
|
||||||
|
[![Build Status](https://api.travis-ci.com/si-cim/prototorch_models.svg?branch=main)](https://travis-ci.com/github/si-cim/prototorch_models)
|
||||||
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
|
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
|
||||||
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
|
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
|
||||||
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
|
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.7.1"
|
release = "0.2.0"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@ -2,252 +2,223 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "7ac5eff0",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"# A short tutorial for the `prototorch.models` plugin"
|
"# A short tutorial for the `prototorch.models` plugin"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "beb83780",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"## Introduction"
|
"## Introduction"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "43b74278",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
|
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
|
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "4e5d1fad",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"## Basics"
|
"## Basics"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1244b66b",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
|
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "dcb88e8a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import prototorch as pt\n",
|
"import prototorch as pt\n",
|
||||||
"import pytorch_lightning as pl\n",
|
"import pytorch_lightning as pl\n",
|
||||||
"import torch"
|
"import torch"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1adbe2f8",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### Building Models"
|
"### Building Models"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "96663ab1",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
|
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "819ba756",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"model = pt.models.GLVQ(\n",
|
"model = pt.models.GLVQ(\n",
|
||||||
" hparams=dict(distribution=[1, 1, 1]),\n",
|
" hparams=dict(distribution=[1, 1, 1]),\n",
|
||||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
||||||
")"
|
")"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "1b37e97c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"print(model)"
|
"print(model)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "d2c86903",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
|
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
|
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "45806052",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### Data"
|
"### Data"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "9d62c4c6",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
|
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "504df02c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"train_ds = pt.datasets.Iris(dims=[0, 2])"
|
"train_ds = pt.datasets.Iris(dims=[0, 2])"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "3b8e7756",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"type(train_ds)"
|
"type(train_ds)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "bce43afa",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"train_ds.data.shape, train_ds.targets.shape"
|
"train_ds.data.shape, train_ds.targets.shape"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "26a83328",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
|
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "67b80fbe",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
|
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "c1185f31",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"type(train_loader)"
|
"type(train_loader)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "9b5a8963",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"x_batch, y_batch = next(iter(train_loader))\n",
|
"x_batch, y_batch = next(iter(train_loader))\n",
|
||||||
"print(f\"{x_batch=}, {y_batch=}\")"
|
"print(f\"{x_batch=}, {y_batch=}\")"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "dd492ee2",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
|
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "5176b055",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### Training"
|
"### Training"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "46a7a506",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
|
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "279e75b7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
|
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "e496b492",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"trainer.fit(model, train_loader)"
|
"trainer.fit(model, train_loader)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "497fbff6",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### From data to a trained model - a very minimal example"
|
"### From data to a trained model - a very minimal example"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "ab069c5d",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
|
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
|
||||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
|
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
|
||||||
@ -259,239 +230,49 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
|
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
|
||||||
"trainer.fit(model, train_loader)"
|
"trainer.fit(model, train_loader)"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "30c71a93",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Saving/Loading trained models"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "f74ed2c1",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "3156658d",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
|
|
||||||
"trainer.save_checkpoint(ckpt_path)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c1c34055",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "bbbb08e9",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Visualizing decision boundaries in 2D"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "53ca52dc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "8373531f",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Saving/Loading trained weights"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "937bc458",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "1f2035af",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
|
|
||||||
"torch.save(model.state_dict(), ckpt_path)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "1206021a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = pt.models.GLVQ(\n",
|
|
||||||
" dict(distribution=(3, 2)),\n",
|
|
||||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "9f2a4beb",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "528d2fc2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"torch.load(ckpt_path)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "ec817e6b",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a208eab7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "f8de748f",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"## Advanced"
|
"## Advanced"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "53a64063",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Warm-start a model with prototypes learned from another model"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "3177c277",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
|
|
||||||
"model = pt.models.SiameseGMLVQ(\n",
|
|
||||||
" dict(input_dim=2,\n",
|
|
||||||
" latent_dim=2,\n",
|
|
||||||
" distribution=(3, 2),\n",
|
|
||||||
" proto_lr=0.0001,\n",
|
|
||||||
" bb_lr=0.0001),\n",
|
|
||||||
" optimizer=torch.optim.Adam,\n",
|
|
||||||
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
|
|
||||||
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
|
|
||||||
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "8baee9a2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(model)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "cc203088",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "1f6a33a5",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "946ce341",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import prototorch as pt\n",
|
"import prototorch as pt\n",
|
||||||
"import pytorch_lightning as pl\n",
|
"import pytorch_lightning as pl\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from torchvision import transforms\n",
|
"from torchvision import transforms\n",
|
||||||
"from torchvision.datasets import MNIST\n",
|
"from torchvision.datasets import MNIST"
|
||||||
"from torchvision.utils import make_grid"
|
],
|
||||||
]
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "510d9bd4",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from matplotlib import pyplot as plt"
|
"from matplotlib import pyplot as plt"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "ea7c1228",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"train_ds = MNIST(\n",
|
"train_ds = MNIST(\n",
|
||||||
" \"~/datasets\",\n",
|
" \"~/datasets\",\n",
|
||||||
@ -503,87 +284,59 @@
|
|||||||
" transforms.ToTensor(),\n",
|
" transforms.ToTensor(),\n",
|
||||||
" ]),\n",
|
" ]),\n",
|
||||||
")"
|
")"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "1b9eaf5c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"s = int(0.05 * len(train_ds))\n",
|
"s = int(0.05 * len(train_ds))\n",
|
||||||
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
|
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "8c32c9f2",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"init_ds"
|
"init_ds"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "68a9a8b9",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"model = pt.models.ImageGLVQ(\n",
|
"model = pt.models.ImageGLVQ(\n",
|
||||||
" dict(distribution=(10, 1)),\n",
|
" dict(distribution=(10, 5)),\n",
|
||||||
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||||
")"
|
")"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "6f23df86",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"plt.imshow(model.get_prototype_grid(num_columns=5))"
|
"plt.imshow(model.get_prototype_grid(num_columns=10))"
|
||||||
]
|
],
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1c23c7b2",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "30780927",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"protos, plabels = pt.components.LabeledComponents(\n",
|
|
||||||
" distribution=(10, 5),\n",
|
|
||||||
" components_initializer=pt.initializers.SMCI(init_ds),\n",
|
|
||||||
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
|
|
||||||
")()\n",
|
|
||||||
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "4fa69f92",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"## FAQs"
|
"## FAQs"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "fa20f9ac",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
|
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -598,12 +351,11 @@
|
|||||||
"```python\n",
|
"```python\n",
|
||||||
">>> model.prototype_labels\n",
|
">>> model.prototype_labels\n",
|
||||||
"```"
|
"```"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "ba8215bf",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
"source": [
|
||||||
"### How do I make inferences/predictions/recall with my trained model?\n",
|
"### How do I make inferences/predictions/recall with my trained model?\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -618,12 +370,13 @@
|
|||||||
"```python\n",
|
"```python\n",
|
||||||
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
|
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
|
||||||
"```"
|
"```"
|
||||||
]
|
],
|
||||||
|
"metadata": {}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "Python 3",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -637,7 +390,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.12"
|
"version": "3.9.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
81
examples/binnam_tecator.py
Normal file
81
examples/binnam_tecator.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Neural Additive Model (NAM) example for binary classification."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Tecator("~/datasets")
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(lr=0.1)
|
||||||
|
|
||||||
|
# Define the feature extractor
|
||||||
|
class FE(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.modules_list = torch.nn.ModuleList([
|
||||||
|
torch.nn.Linear(1, 3),
|
||||||
|
torch.nn.Sigmoid(),
|
||||||
|
torch.nn.Linear(3, 1),
|
||||||
|
torch.nn.Sigmoid(),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for m in self.modules_list:
|
||||||
|
x = m(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.BinaryNAM(
|
||||||
|
hparams,
|
||||||
|
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 100)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=20,
|
||||||
|
mode="min",
|
||||||
|
verbose=True,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
terminate_on_nan=True,
|
||||||
|
weights_summary=None,
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Visualize extractor shape functions
|
||||||
|
fig, axes = plt.subplots(10, 10)
|
||||||
|
for i, ax in enumerate(axes.flat):
|
||||||
|
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
|
||||||
|
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
|
||||||
|
ax.plot(x, y)
|
||||||
|
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
|
||||||
|
plt.show()
|
86
examples/binnam_xor.py
Normal file
86
examples/binnam_xor.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""Neural Additive Model (NAM) example for binary classification."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.XOR()
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(lr=0.001)
|
||||||
|
|
||||||
|
# Define the feature extractor
|
||||||
|
class FE(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size=10):
|
||||||
|
super().__init__()
|
||||||
|
self.modules_list = torch.nn.ModuleList([
|
||||||
|
torch.nn.Linear(1, hidden_size),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(hidden_size, 1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for m in self.modules_list:
|
||||||
|
x = m(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.BinaryNAM(
|
||||||
|
hparams,
|
||||||
|
extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.Vis2D(data=train_ds)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=50,
|
||||||
|
mode="min",
|
||||||
|
verbose=False,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
terminate_on_nan=True,
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Visualize extractor shape functions
|
||||||
|
fig, axes = plt.subplots(2)
|
||||||
|
for i, ax in enumerate(axes.flat):
|
||||||
|
x = torch.linspace(0, 1, 100) # TODO use min/max from data
|
||||||
|
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
|
||||||
|
ax.plot(x, y)
|
||||||
|
ax.set(title=f"Feature {i + 1}")
|
||||||
|
plt.show()
|
@ -1,32 +1,25 @@
|
|||||||
"""CBC example using the Iris dataset."""
|
"""CBC example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
import torch
|
||||||
from prototorch.models import CBC, VisCBC2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=32)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -37,32 +30,23 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = CBC(
|
model = pt.models.CBC(
|
||||||
hparams,
|
hparams,
|
||||||
components_initializer=pt.initializers.SSCI(train_ds, noise=0.1),
|
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
|
||||||
reasonings_initializer=pt.initializers.
|
reasonings_iniitializer=pt.initializers.
|
||||||
PurePositiveReasoningsInitializer(),
|
PurePositiveReasoningsInitializer(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisCBC2D(
|
vis = pt.models.VisCBC2D(data=train_ds,
|
||||||
data=train_ds,
|
title="CBC Iris Example",
|
||||||
title="CBC Iris Example",
|
resolution=100,
|
||||||
resolution=100,
|
axis_off=True)
|
||||||
axis_off=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis],
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
detect_anomaly=True,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
max_epochs=1000,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
8
examples/cli/README.md
Normal file
8
examples/cli/README.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Examples using Lightning CLI
|
||||||
|
|
||||||
|
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
|
||||||
|
|
||||||
|
To use the example run
|
||||||
|
```
|
||||||
|
python gmlvq.py --config gmlvq.yaml
|
||||||
|
```
|
19
examples/cli/gmlvq.py
Normal file
19
examples/cli/gmlvq.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""GMLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import torch
|
||||||
|
from prototorch.models import ImageGMLVQ
|
||||||
|
from prototorch.models.abstract import PrototypeModel
|
||||||
|
from prototorch.models.data import MNISTDataModule
|
||||||
|
from pytorch_lightning.utilities.cli import LightningCLI
|
||||||
|
|
||||||
|
|
||||||
|
class ExperimentClass(ImageGMLVQ):
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.zeros(28 * 28),
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
|
11
examples/cli/gmlvq.yaml
Normal file
11
examples/cli/gmlvq.yaml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
model:
|
||||||
|
hparams:
|
||||||
|
input_dim: 784
|
||||||
|
latent_dim: 784
|
||||||
|
distribution:
|
||||||
|
num_classes: 10
|
||||||
|
prototypes_per_class: 2
|
||||||
|
proto_lr: 0.01
|
||||||
|
bb_lr: 0.01
|
||||||
|
data:
|
||||||
|
batch_size: 32
|
@ -1,50 +1,30 @@
|
|||||||
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
|
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
CELVQ,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisGLVQ2D,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
num_classes = 4
|
num_classes = 4
|
||||||
num_features = 2
|
num_features = 2
|
||||||
num_clusters = 1
|
num_clusters = 1
|
||||||
train_ds = pt.datasets.Random(
|
train_ds = pt.datasets.Random(num_samples=500,
|
||||||
num_samples=500,
|
num_classes=num_classes,
|
||||||
num_classes=num_classes,
|
num_features=num_features,
|
||||||
num_features=num_features,
|
num_clusters=num_clusters,
|
||||||
num_clusters=num_clusters,
|
separation=3.0,
|
||||||
separation=3.0,
|
seed=42)
|
||||||
seed=42,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=256)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
prototypes_per_class = num_clusters * 5
|
prototypes_per_class = num_clusters * 5
|
||||||
@ -54,7 +34,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = CELVQ(
|
model = pt.models.CELVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||||
)
|
)
|
||||||
@ -63,18 +43,18 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
logging.info(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(train_ds)
|
vis = pt.models.VisGLVQ2D(train_ds)
|
||||||
pruning = PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01, # prune prototype if it wins less than 1%
|
threshold=0.01, # prune prototype if it wins less than 1%
|
||||||
idle_epochs=20, # pruning too early may cause problems
|
idle_epochs=20, # pruning too early may cause problems
|
||||||
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
|
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
|
||||||
frequency=1, # prune every epoch
|
frequency=1, # prune every epoch
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_loss",
|
monitor="train_loss",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=20,
|
patience=20,
|
||||||
@ -84,18 +64,17 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
es,
|
es,
|
||||||
],
|
],
|
||||||
detect_anomaly=True,
|
progress_bar_refresh_rate=0,
|
||||||
log_every_n_steps=1,
|
terminate_on_nan=True,
|
||||||
max_epochs=1000,
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,35 +1,23 @@
|
|||||||
"""GLVQ example using the Iris dataset."""
|
"""GLVQ example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import GLVQ, VisGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -41,7 +29,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GLVQ(
|
model = pt.models.GLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
@ -53,30 +41,15 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis],
|
||||||
fast_dev_run=args.fast_dev_run,
|
weights_summary="full",
|
||||||
callbacks=[
|
accelerator="ddp",
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=100,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
# Manual save
|
|
||||||
trainer.save_checkpoint("./glvq_iris.ckpt")
|
|
||||||
|
|
||||||
# Load saved model
|
|
||||||
new_model = GLVQ.load_from_checkpoint(
|
|
||||||
checkpoint_path="./glvq_iris.ckpt",
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
logging.info(new_model)
|
|
||||||
|
@ -1,39 +1,22 @@
|
|||||||
"""GMLVQ example using the spiral dataset."""
|
"""GLVQ example using the spiral dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
GMLVQ,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisGLVQ2D,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
|
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=256)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
@ -49,19 +32,19 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GMLVQ(
|
model = pt.models.GMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(
|
vis = pt.models.VisGLVQ2D(
|
||||||
train_ds,
|
train_ds,
|
||||||
show_last_only=False,
|
show_last_only=False,
|
||||||
block=False,
|
block=False,
|
||||||
)
|
)
|
||||||
pruning = PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01,
|
threshold=0.01,
|
||||||
idle_epochs=10,
|
idle_epochs=10,
|
||||||
prune_quota_per_epoch=5,
|
prune_quota_per_epoch=5,
|
||||||
@ -70,7 +53,7 @@ if __name__ == "__main__":
|
|||||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_loss",
|
monitor="train_loss",
|
||||||
min_delta=1.0,
|
min_delta=1.0,
|
||||||
patience=5,
|
patience=5,
|
||||||
@ -79,18 +62,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
pruning,
|
pruning,
|
||||||
],
|
],
|
||||||
max_epochs=1000,
|
terminate_on_nan=True,
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
@ -1,78 +0,0 @@
|
|||||||
"""GMLVQ example using the Iris dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import GMLVQ, VisGMLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Iris()
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(train_ds, batch_size=64)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
hparams = dict(
|
|
||||||
input_dim=4,
|
|
||||||
latent_dim=4,
|
|
||||||
distribution={
|
|
||||||
"num_classes": 3,
|
|
||||||
"per_class": 2
|
|
||||||
},
|
|
||||||
proto_lr=0.01,
|
|
||||||
bb_lr=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = GMLVQ(
|
|
||||||
hparams,
|
|
||||||
optimizer=torch.optim.Adam,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
lr_scheduler=ExponentialLR,
|
|
||||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
|
||||||
model.example_input_array = torch.zeros(4, 4)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisGMLVQ2D(data=train_ds)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=100,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
||||||
|
|
||||||
torch.save(model, "iris.pth")
|
|
@ -1,33 +1,17 @@
|
|||||||
"""GMLVQ example using the MNIST dataset."""
|
"""GMLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
ImageGMLVQ,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisImgComp,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -49,8 +33,12 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, num_workers=4, batch_size=256)
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
test_loader = DataLoader(test_ds, num_workers=4, batch_size=256)
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
@ -64,14 +52,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = ImageGMLVQ(
|
model = pt.models.ImageGMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisImgComp(
|
vis = pt.models.VisImgComp(
|
||||||
data=train_ds,
|
data=train_ds,
|
||||||
num_columns=10,
|
num_columns=10,
|
||||||
show=False,
|
show=False,
|
||||||
@ -81,14 +69,14 @@ if __name__ == "__main__":
|
|||||||
embedding_data=200,
|
embedding_data=200,
|
||||||
flatten_data=False,
|
flatten_data=False,
|
||||||
)
|
)
|
||||||
pruning = PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01,
|
threshold=0.01,
|
||||||
idle_epochs=1,
|
idle_epochs=1,
|
||||||
prune_quota_per_epoch=10,
|
prune_quota_per_epoch=10,
|
||||||
frequency=1,
|
frequency=1,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_loss",
|
monitor="train_loss",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=15,
|
patience=15,
|
||||||
@ -97,18 +85,16 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
es,
|
# es,
|
||||||
],
|
],
|
||||||
max_epochs=1000,
|
terminate_on_nan=True,
|
||||||
log_every_n_steps=1,
|
weights_summary=None,
|
||||||
detect_anomaly=True,
|
# accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,33 +1,23 @@
|
|||||||
"""Growing Neural Gas example using the Iris dataset."""
|
"""Growing Neural Gas example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import GrowingNeuralGas, VisNG2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=42)
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
train_loader = DataLoader(train_ds, batch_size=64)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -37,7 +27,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GrowingNeuralGas(
|
model = pt.models.GrowingNeuralGas(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.ZCI(2),
|
prototypes_initializer=pt.initializers.ZCI(2),
|
||||||
)
|
)
|
||||||
@ -46,22 +36,17 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
logging.info(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisNG2D(data=train_loader)
|
vis = pt.models.VisNG2D(data=train_loader)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=100,
|
max_epochs=100,
|
||||||
log_every_n_steps=1,
|
callbacks=[vis],
|
||||||
detect_anomaly=True,
|
weights_summary="full",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,77 +0,0 @@
|
|||||||
"""GMLVQ example using the Iris dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Iris([0, 1])
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(train_ds, batch_size=64)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
hparams = dict(
|
|
||||||
input_dim=2,
|
|
||||||
distribution={
|
|
||||||
"num_classes": 3,
|
|
||||||
"per_class": 2
|
|
||||||
},
|
|
||||||
proto_lr=0.01,
|
|
||||||
bb_lr=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = GRLVQ(
|
|
||||||
hparams,
|
|
||||||
optimizer=torch.optim.Adam,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
lr_scheduler=ExponentialLR,
|
|
||||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisSiameseGLVQ2D(data=train_ds)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=5,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
||||||
|
|
||||||
torch.save(model, "iris.pth")
|
|
@ -1,119 +0,0 @@
|
|||||||
"""GTLVQ example using the MNIST dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
ImageGTLVQ,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisImgComp,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torchvision import transforms
|
|
||||||
from torchvision.datasets import MNIST
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = MNIST(
|
|
||||||
"~/datasets",
|
|
||||||
train=True,
|
|
||||||
download=True,
|
|
||||||
transform=transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
test_ds = MNIST(
|
|
||||||
"~/datasets",
|
|
||||||
train=False,
|
|
||||||
download=True,
|
|
||||||
transform=transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
|
|
||||||
test_loader = DataLoader(test_ds, num_workers=0, batch_size=256)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
num_classes = 10
|
|
||||||
prototypes_per_class = 1
|
|
||||||
hparams = dict(
|
|
||||||
input_dim=28 * 28,
|
|
||||||
latent_dim=28,
|
|
||||||
distribution=(num_classes, prototypes_per_class),
|
|
||||||
proto_lr=0.01,
|
|
||||||
bb_lr=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = ImageGTLVQ(
|
|
||||||
hparams,
|
|
||||||
optimizer=torch.optim.Adam,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
#Use one batch of data for subspace initiator.
|
|
||||||
omega_initializer=pt.initializers.PCALinearTransformInitializer(
|
|
||||||
next(iter(train_loader))[0].reshape(256, 28 * 28)))
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisImgComp(
|
|
||||||
data=train_ds,
|
|
||||||
num_columns=10,
|
|
||||||
show=False,
|
|
||||||
tensorboard=True,
|
|
||||||
random_data=100,
|
|
||||||
add_embedding=True,
|
|
||||||
embedding_data=200,
|
|
||||||
flatten_data=False,
|
|
||||||
)
|
|
||||||
pruning = PruneLoserPrototypes(
|
|
||||||
threshold=0.01,
|
|
||||||
idle_epochs=1,
|
|
||||||
prune_quota_per_epoch=10,
|
|
||||||
frequency=1,
|
|
||||||
verbose=True,
|
|
||||||
)
|
|
||||||
es = EarlyStopping(
|
|
||||||
monitor="train_loss",
|
|
||||||
min_delta=0.001,
|
|
||||||
patience=15,
|
|
||||||
mode="min",
|
|
||||||
check_on_train_epoch_end=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
# using GPUs here is strongly recommended!
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
pruning,
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
@ -1,79 +0,0 @@
|
|||||||
"""Localized-GTLVQ example using the Moons dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import GTLVQ, VisGLVQ2D
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=2)
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_ds,
|
|
||||||
batch_size=256,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
# Latent_dim should be lower than input dim.
|
|
||||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = GTLVQ(hparams,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds))
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
logging.info(model)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
|
||||||
es = EarlyStopping(
|
|
||||||
monitor="train_acc",
|
|
||||||
min_delta=0.001,
|
|
||||||
patience=20,
|
|
||||||
mode="max",
|
|
||||||
verbose=False,
|
|
||||||
check_on_train_epoch_end=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
@ -1,75 +1,51 @@
|
|||||||
"""k-NN example using the Iris dataset from scikit-learn."""
|
"""k-NN example using the Iris dataset from scikit-learn."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from prototorch.models import KNN, VisGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
X, y = load_iris(return_X_y=True)
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
X = X[:, 0:3:2]
|
x_train = x_train[:, [0, 2]]
|
||||||
|
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
|
||||||
X,
|
|
||||||
y,
|
|
||||||
test_size=0.5,
|
|
||||||
random_state=42,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
|
|
||||||
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
|
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=16)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
test_loader = DataLoader(test_ds, batch_size=16)
|
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(k=5)
|
hparams = dict(k=5)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = KNN(hparams, data=train_ds)
|
model = pt.models.KNN(hparams, data=train_ds)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
logging.info(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(
|
vis = pt.models.VisGLVQ2D(
|
||||||
data=(X_train, y_train),
|
data=(x_train, y_train),
|
||||||
resolution=200,
|
resolution=200,
|
||||||
block=True,
|
block=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
max_epochs=1,
|
max_epochs=1,
|
||||||
callbacks=[
|
callbacks=[vis],
|
||||||
vis,
|
weights_summary="full",
|
||||||
],
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
@ -77,8 +53,5 @@ if __name__ == "__main__":
|
|||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
# Recall
|
# Recall
|
||||||
y_pred = model.predict(torch.tensor(X_train))
|
y_pred = model.predict(torch.tensor(x_train))
|
||||||
logging.info(y_pred)
|
print(y_pred)
|
||||||
|
|
||||||
# Test
|
|
||||||
trainer.test(model, dataloaders=test_loader)
|
|
||||||
|
@ -1,25 +1,15 @@
|
|||||||
"""Kohonen Self Organizing Map."""
|
"""Kohonen Self Organizing Map."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.models import KohonenSOM
|
|
||||||
from prototorch.utils.colors import hex_to_rgb
|
from prototorch.utils.colors import hex_to_rgb
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
|
|
||||||
class Vis2DColorSOM(pl.Callback):
|
class Vis2DColorSOM(pl.Callback):
|
||||||
|
|
||||||
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
|
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.title = title
|
self.title = title
|
||||||
@ -27,7 +17,7 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
self.data = data
|
self.data = data
|
||||||
self.pause_time = pause_time
|
self.pause_time = pause_time
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module: KohonenSOM):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(self.title)
|
ax.set_title(self.title)
|
||||||
@ -40,14 +30,12 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
d = pl_module.compute_distances(self.data)
|
d = pl_module.compute_distances(self.data)
|
||||||
wp = pl_module.predict_from_distances(d)
|
wp = pl_module.predict_from_distances(d)
|
||||||
for i, iloc in enumerate(wp):
|
for i, iloc in enumerate(wp):
|
||||||
plt.text(
|
plt.text(iloc[1],
|
||||||
iloc[1],
|
iloc[0],
|
||||||
iloc[0],
|
cnames[i],
|
||||||
color_names[i],
|
ha="center",
|
||||||
ha="center",
|
va="center",
|
||||||
va="center",
|
bbox=dict(facecolor="white", alpha=0.5, lw=0))
|
||||||
bbox=dict(facecolor="white", alpha=0.5, lw=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||||
plt.pause(self.pause_time)
|
plt.pause(self.pause_time)
|
||||||
@ -58,12 +46,11 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=42)
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
hex_colors = [
|
hex_colors = [
|
||||||
@ -71,15 +58,15 @@ if __name__ == "__main__":
|
|||||||
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
||||||
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
|
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
|
||||||
]
|
]
|
||||||
color_names = [
|
cnames = [
|
||||||
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
||||||
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
|
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
|
||||||
"lightgrey", "olive", "purple", "orange"
|
"lightgrey", "olive", "purple", "orange"
|
||||||
]
|
]
|
||||||
colors = list(hex_to_rgb(hex_colors))
|
colors = list(hex_to_rgb(hex_colors))
|
||||||
data = torch.Tensor(colors) / 255.0
|
data = torch.Tensor(colors) / 255.0
|
||||||
train_ds = TensorDataset(data)
|
train_ds = torch.utils.data.TensorDataset(data)
|
||||||
train_loader = DataLoader(train_ds, batch_size=8)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -90,7 +77,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = KohonenSOM(
|
model = pt.models.KohonenSOM(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.RNCI(3),
|
prototypes_initializer=pt.initializers.RNCI(3),
|
||||||
)
|
)
|
||||||
@ -99,22 +86,17 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 3)
|
model.example_input_array = torch.zeros(4, 3)
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
logging.info(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = Vis2DColorSOM(data=data)
|
vis = Vis2DColorSOM(data=data)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
max_epochs=500,
|
max_epochs=500,
|
||||||
callbacks=[
|
callbacks=[vis],
|
||||||
vis,
|
weights_summary="full",
|
||||||
],
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,36 +1,27 @@
|
|||||||
"""Localized-GMLVQ example using the Moons dataset."""
|
"""Localized-GMLVQ example using the Moons dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import LGMLVQ, VisGLVQ2D
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=2)
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
batch_size=256,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -40,7 +31,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = LGMLVQ(
|
model = pt.models.LGMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
)
|
)
|
||||||
@ -49,11 +40,11 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
logging.info(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_acc",
|
monitor="train_acc",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=20,
|
patience=20,
|
||||||
@ -63,17 +54,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
],
|
],
|
||||||
log_every_n_steps=1,
|
weights_summary="full",
|
||||||
max_epochs=1000,
|
accelerator="ddp",
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,26 +1,13 @@
|
|||||||
"""LVQMLN example using all four dimensions of the Iris dataset."""
|
"""LVQMLN example using all four dimensions of the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
LVQMLN,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisSiameseGLVQ2D,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(torch.nn.Module):
|
class Backbone(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@ -39,18 +26,17 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris()
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=42)
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=150)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -63,7 +49,7 @@ if __name__ == "__main__":
|
|||||||
backbone = Backbone()
|
backbone = Backbone()
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = LVQMLN(
|
model = pt.models.LVQMLN(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.SSCI(
|
prototypes_initializer=pt.initializers.SSCI(
|
||||||
train_ds,
|
train_ds,
|
||||||
@ -72,15 +58,18 @@ if __name__ == "__main__":
|
|||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisSiameseGLVQ2D(
|
vis = pt.models.VisSiameseGLVQ2D(
|
||||||
data=train_ds,
|
data=train_ds,
|
||||||
map_protos=False,
|
map_protos=False,
|
||||||
border=0.1,
|
border=0.1,
|
||||||
resolution=500,
|
resolution=500,
|
||||||
axis_off=True,
|
axis_off=True,
|
||||||
)
|
)
|
||||||
pruning = PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01,
|
threshold=0.01,
|
||||||
idle_epochs=20,
|
idle_epochs=20,
|
||||||
prune_quota_per_epoch=2,
|
prune_quota_per_epoch=2,
|
||||||
@ -89,17 +78,12 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
],
|
],
|
||||||
log_every_n_steps=1,
|
|
||||||
max_epochs=1000,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,40 +1,28 @@
|
|||||||
"""Median-LVQ example using the Iris dataset."""
|
"""Median-LVQ example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import MedianLVQ, VisGLVQ2D
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
train_ds,
|
train_ds,
|
||||||
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
|
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = MedianLVQ(
|
model = pt.models.MedianLVQ(
|
||||||
hparams=dict(distribution=(3, 2), lr=0.01),
|
hparams=dict(distribution=(3, 2), lr=0.01),
|
||||||
prototypes_initializer=pt.initializers.SSCI(train_ds),
|
prototypes_initializer=pt.initializers.SSCI(train_ds),
|
||||||
)
|
)
|
||||||
@ -43,8 +31,8 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_acc",
|
monitor="train_acc",
|
||||||
min_delta=0.01,
|
min_delta=0.01,
|
||||||
patience=5,
|
patience=5,
|
||||||
@ -54,17 +42,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis, es],
|
||||||
fast_dev_run=args.fast_dev_run,
|
weights_summary="full",
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,35 +1,23 @@
|
|||||||
"""Neural Gas example using the Iris dataset."""
|
"""Neural Gas example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import NeuralGas, VisNG2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare and pre-process the dataset
|
# Prepare and pre-process the dataset
|
||||||
x_train, y_train = load_iris(return_X_y=True)
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
x_train = x_train[:, 0:3:2]
|
x_train = x_train[:, [0, 2]]
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
scaler.fit(x_train)
|
scaler.fit(x_train)
|
||||||
x_train = scaler.transform(x_train)
|
x_train = scaler.transform(x_train)
|
||||||
@ -37,7 +25,7 @@ if __name__ == "__main__":
|
|||||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=150)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -47,7 +35,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = NeuralGas(
|
model = pt.models.NeuralGas(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.core.ZCI(2),
|
prototypes_initializer=pt.core.ZCI(2),
|
||||||
lr_scheduler=ExponentialLR,
|
lr_scheduler=ExponentialLR,
|
||||||
@ -57,20 +45,17 @@ if __name__ == "__main__":
|
|||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisNG2D(data=train_ds)
|
vis = pt.models.VisNG2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis],
|
||||||
fast_dev_run=args.fast_dev_run,
|
weights_summary="full",
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,34 +1,25 @@
|
|||||||
"""RSLVQ example using the Iris dataset."""
|
"""RSLVQ example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import RSLVQ, VisGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=42)
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=64)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -42,7 +33,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = RSLVQ(
|
model = pt.models.RSLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
|
||||||
@ -51,20 +42,19 @@ if __name__ == "__main__":
|
|||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis],
|
||||||
fast_dev_run=args.fast_dev_run,
|
terminate_on_nan=True,
|
||||||
callbacks=[
|
weights_summary="full",
|
||||||
vis,
|
accelerator="ddp",
|
||||||
],
|
|
||||||
detect_anomaly=True,
|
|
||||||
max_epochs=100,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,22 +1,13 @@
|
|||||||
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
|
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(torch.nn.Module):
|
class Backbone(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
@ -35,50 +26,46 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris()
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
seed_everything(seed=2)
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, batch_size=150)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[1, 2, 3],
|
distribution=[1, 2, 3],
|
||||||
lr=0.01,
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the backbone
|
# Initialize the backbone
|
||||||
backbone = Backbone()
|
backbone = Backbone()
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = SiameseGLVQ(
|
model = pt.models.SiameseGLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
both_path_gradients=False,
|
both_path_gradients=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
callbacks=[vis],
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,87 +0,0 @@
|
|||||||
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(torch.nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
|
||||||
super().__init__()
|
|
||||||
self.input_size = input_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.latent_size = latent_size
|
|
||||||
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
|
||||||
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
|
||||||
self.activation = torch.nn.Sigmoid()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.activation(self.dense1(x))
|
|
||||||
out = self.activation(self.dense2(x))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Iris()
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=2)
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(train_ds, batch_size=150)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
hparams = dict(
|
|
||||||
distribution=[1, 2, 3],
|
|
||||||
lr=0.01,
|
|
||||||
input_dim=2,
|
|
||||||
latent_dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the backbone
|
|
||||||
backbone = Backbone(latent_size=hparams["input_dim"])
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = SiameseGTLVQ(
|
|
||||||
hparams,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
backbone=backbone,
|
|
||||||
both_path_gradients=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
@ -1,42 +1,24 @@
|
|||||||
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
|
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from lightning_fabric.utilities.seed import seed_everything
|
|
||||||
from prototorch.models import (
|
|
||||||
GLVQ,
|
|
||||||
KNN,
|
|
||||||
GrowingNeuralGas,
|
|
||||||
PruneLoserPrototypes,
|
|
||||||
VisGLVQ2D,
|
|
||||||
)
|
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed_everything(seed=4)
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--gpus", type=int, default=0)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
# Initialize the gng
|
# Initialize the gng
|
||||||
gng = GrowingNeuralGas(
|
gng = pt.models.GrowingNeuralGas(
|
||||||
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
|
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
|
||||||
prototypes_initializer=pt.initializers.ZCI(2),
|
prototypes_initializer=pt.initializers.ZCI(2),
|
||||||
lr_scheduler=ExponentialLR,
|
lr_scheduler=ExponentialLR,
|
||||||
@ -44,7 +26,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="loss",
|
monitor="loss",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=20,
|
patience=20,
|
||||||
@ -55,14 +37,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer for GNG
|
# Setup trainer for GNG
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
accelerator="cpu",
|
max_epochs=100,
|
||||||
max_epochs=50 if args.fast_dev_run else
|
callbacks=[es],
|
||||||
1000, # 10 epochs fast dev run reproducible DIV error.
|
weights_summary=None,
|
||||||
callbacks=[
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
log_every_n_steps=1,
|
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
@ -75,12 +52,12 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Warm-start prototypes
|
# Warm-start prototypes
|
||||||
knn = KNN(dict(k=1), data=train_ds)
|
knn = pt.models.KNN(dict(k=1), data=train_ds)
|
||||||
prototypes = gng.prototypes
|
prototypes = gng.prototypes
|
||||||
plabels = knn.predict(prototypes)
|
plabels = knn.predict(prototypes)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GLVQ(
|
model = pt.models.GLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototypes_initializer=pt.initializers.LCI(prototypes),
|
prototypes_initializer=pt.initializers.LCI(prototypes),
|
||||||
@ -93,15 +70,15 @@ if __name__ == "__main__":
|
|||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
pruning = PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.02,
|
threshold=0.02,
|
||||||
idle_epochs=2,
|
idle_epochs=2,
|
||||||
prune_quota_per_epoch=5,
|
prune_quota_per_epoch=5,
|
||||||
frequency=1,
|
frequency=1,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
monitor="train_loss",
|
monitor="train_loss",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=10,
|
patience=10,
|
||||||
@ -111,18 +88,15 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
accelerator="cuda" if args.gpus else "cpu",
|
args,
|
||||||
devices=args.gpus if args.gpus else "auto",
|
|
||||||
fast_dev_run=args.fast_dev_run,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
es,
|
es,
|
||||||
],
|
],
|
||||||
max_epochs=1000,
|
weights_summary="full",
|
||||||
log_every_n_steps=1,
|
accelerator="ddp",
|
||||||
detect_anomaly=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""`models` plugin for the `prototorch` package."""
|
"""`models` plugin for the `prototorch` package."""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (
|
from .glvq import (
|
||||||
@ -8,32 +10,18 @@ from .glvq import (
|
|||||||
GLVQ21,
|
GLVQ21,
|
||||||
GMLVQ,
|
GMLVQ,
|
||||||
GRLVQ,
|
GRLVQ,
|
||||||
GTLVQ,
|
|
||||||
LGMLVQ,
|
LGMLVQ,
|
||||||
LVQMLN,
|
LVQMLN,
|
||||||
ImageGLVQ,
|
ImageGLVQ,
|
||||||
ImageGMLVQ,
|
ImageGMLVQ,
|
||||||
ImageGTLVQ,
|
|
||||||
SiameseGLVQ,
|
SiameseGLVQ,
|
||||||
SiameseGMLVQ,
|
SiameseGMLVQ,
|
||||||
SiameseGTLVQ,
|
|
||||||
)
|
)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import (
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
LVQ1,
|
from .nam import BinaryNAM
|
||||||
LVQ21,
|
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||||
MedianLVQ,
|
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||||
)
|
|
||||||
from .probabilistic import (
|
|
||||||
CELVQ,
|
|
||||||
RSLVQ,
|
|
||||||
SLVQ,
|
|
||||||
)
|
|
||||||
from .unsupervised import (
|
|
||||||
GrowingNeuralGas,
|
|
||||||
KohonenSOM,
|
|
||||||
NeuralGas,
|
|
||||||
)
|
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.7.1"
|
__version__ = "0.2.0"
|
@ -1,29 +1,21 @@
|
|||||||
"""Abstract classes to be inherited by prototorch models."""
|
"""Abstract classes to be inherited by prototorch models."""
|
||||||
|
|
||||||
import logging
|
from typing import Final, final
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core.competitions import WTAC
|
|
||||||
from prototorch.core.components import (
|
from ..core.competitions import WTAC
|
||||||
AbstractComponents,
|
from ..core.components import Components, LabeledComponents
|
||||||
Components,
|
from ..core.distances import euclidean_distance
|
||||||
LabeledComponents,
|
from ..core.initializers import LabelsInitializer
|
||||||
)
|
from ..core.pooling import stratified_min_pooling
|
||||||
from prototorch.core.distances import euclidean_distance
|
from ..nn.wrappers import LambdaLayer
|
||||||
from prototorch.core.initializers import (
|
|
||||||
LabelsInitializer,
|
|
||||||
ZerosCompInitializer,
|
|
||||||
)
|
|
||||||
from prototorch.core.pooling import stratified_min_pooling
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchBolt(pl.LightningModule):
|
class ProtoTorchBolt(pl.LightningModule):
|
||||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -39,7 +31,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
|||||||
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
if self.lr_scheduler is not None:
|
if self.lr_scheduler is not None:
|
||||||
scheduler = self.lr_scheduler(optimizer,
|
scheduler = self.lr_scheduler(optimizer,
|
||||||
**self.lr_scheduler_kwargs)
|
**self.lr_scheduler_kwargs)
|
||||||
@ -51,11 +43,9 @@ class ProtoTorchBolt(pl.LightningModule):
|
|||||||
else:
|
else:
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
@final
|
||||||
def reconfigure_optimizers(self):
|
def reconfigure_optimizers(self):
|
||||||
if self.trainer:
|
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||||
self.trainer.strategy.setup_optimizers(self.trainer)
|
|
||||||
else:
|
|
||||||
logging.warning("No trainer to reconfigure optimizers!")
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
surep = super().__repr__()
|
surep = super().__repr__()
|
||||||
@ -65,13 +55,11 @@ class ProtoTorchBolt(pl.LightningModule):
|
|||||||
|
|
||||||
|
|
||||||
class PrototypeModel(ProtoTorchBolt):
|
class PrototypeModel(ProtoTorchBolt):
|
||||||
proto_layer: AbstractComponents
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
|
self.distance_layer = LambdaLayer(distance_fn)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_prototypes(self):
|
def num_prototypes(self):
|
||||||
@ -88,18 +76,14 @@ class PrototypeModel(ProtoTorchBolt):
|
|||||||
|
|
||||||
def add_prototypes(self, *args, **kwargs):
|
def add_prototypes(self, *args, **kwargs):
|
||||||
self.proto_layer.add_components(*args, **kwargs)
|
self.proto_layer.add_components(*args, **kwargs)
|
||||||
self.hparams["distribution"] = self.proto_layer.distribution
|
|
||||||
self.reconfigure_optimizers()
|
self.reconfigure_optimizers()
|
||||||
|
|
||||||
def remove_prototypes(self, indices):
|
def remove_prototypes(self, indices):
|
||||||
self.proto_layer.remove_components(indices)
|
self.proto_layer.remove_components(indices)
|
||||||
self.hparams["distribution"] = self.proto_layer.distribution
|
|
||||||
self.reconfigure_optimizers()
|
self.reconfigure_optimizers()
|
||||||
|
|
||||||
|
|
||||||
class UnsupervisedPrototypeModel(PrototypeModel):
|
class UnsupervisedPrototypeModel(PrototypeModel):
|
||||||
proto_layer: Components
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
@ -107,12 +91,12 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
if prototypes_initializer is not None:
|
if prototypes_initializer is not None:
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams["num_prototypes"],
|
self.hparams.num_prototypes,
|
||||||
initializer=prototypes_initializer,
|
initializer=prototypes_initializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos = self.proto_layer().type_as(x)
|
protos = self.proto_layer()
|
||||||
distances = self.distance_layer(x, protos)
|
distances = self.distance_layer(x, protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@ -122,34 +106,19 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
|
|
||||||
class SupervisedPrototypeModel(PrototypeModel):
|
class SupervisedPrototypeModel(PrototypeModel):
|
||||||
proto_layer: LabeledComponents
|
def __init__(self, hparams, **kwargs):
|
||||||
|
|
||||||
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
distribution = hparams.get("distribution", None)
|
|
||||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
labels_initializer = kwargs.get("labels_initializer",
|
labels_initializer = kwargs.get("labels_initializer",
|
||||||
LabelsInitializer())
|
LabelsInitializer())
|
||||||
if not skip_proto_layer:
|
if prototypes_initializer is not None:
|
||||||
# when subclasses do not need a customized prototype layer
|
self.proto_layer = LabeledComponents(
|
||||||
if prototypes_initializer is not None:
|
distribution=self.hparams.distribution,
|
||||||
# when building a new model
|
components_initializer=prototypes_initializer,
|
||||||
self.proto_layer = LabeledComponents(
|
labels_initializer=labels_initializer,
|
||||||
distribution=distribution,
|
)
|
||||||
components_initializer=prototypes_initializer,
|
|
||||||
labels_initializer=labels_initializer,
|
|
||||||
)
|
|
||||||
proto_shape = self.proto_layer.components.shape[1:]
|
|
||||||
self.hparams["initialized_proto_shape"] = proto_shape
|
|
||||||
else:
|
|
||||||
# when restoring a checkpointed model
|
|
||||||
self.proto_layer = LabeledComponents(
|
|
||||||
distribution=distribution,
|
|
||||||
components_initializer=ZerosCompInitializer(
|
|
||||||
self.hparams["initialized_proto_shape"]),
|
|
||||||
)
|
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -167,14 +136,14 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self.compute_distances(x)
|
distances = self.compute_distances(x)
|
||||||
_, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
winning = stratified_min_pooling(distances, plabels)
|
winning = stratified_min_pooling(distances, plabels)
|
||||||
y_pred = F.softmin(winning, dim=1)
|
y_pred = torch.nn.functional.softmin(winning)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
def predict_from_distances(self, distances):
|
def predict_from_distances(self, distances):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
y_pred = self.competition_layer(distances, plabels)
|
y_pred = self.competition_layer(distances, plabels)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
@ -186,57 +155,36 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
def log_acc(self, distances, targets, tag):
|
def log_acc(self, distances, targets, tag):
|
||||||
preds = self.predict_from_distances(distances)
|
preds = self.predict_from_distances(distances)
|
||||||
accuracy = torchmetrics.functional.accuracy(
|
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||||
preds.int(),
|
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||||
targets.int(),
|
|
||||||
"multiclass",
|
|
||||||
num_classes=self.num_classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log(
|
self.log(tag,
|
||||||
tag,
|
accuracy,
|
||||||
accuracy,
|
on_step=False,
|
||||||
on_step=False,
|
on_epoch=True,
|
||||||
on_epoch=True,
|
prog_bar=True,
|
||||||
prog_bar=True,
|
logger=True)
|
||||||
logger=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
|
||||||
x, targets = batch
|
|
||||||
|
|
||||||
preds = self.predict(x)
|
|
||||||
accuracy = torchmetrics.functional.accuracy(
|
|
||||||
preds.int(),
|
|
||||||
targets.int(),
|
|
||||||
"multiclass",
|
|
||||||
num_classes=self.num_classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log("test_acc", accuracy)
|
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin:
|
class ProtoTorchMixin(object):
|
||||||
"""All mixins are ProtoTorchMixins."""
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NonGradientMixin(ProtoTorchMixin):
|
class NonGradientMixin(ProtoTorchMixin):
|
||||||
"""Mixin for custom non-gradient optimization."""
|
"""Mixin for custom non-gradient optimization."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization: Final = False
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||||
"""Mixin for models with image prototypes."""
|
"""Mixin for models with image prototypes."""
|
||||||
proto_layer: Components
|
@final
|
||||||
components: torch.Tensor
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
|
||||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||||
|
|
@ -1,30 +1,24 @@
|
|||||||
"""Lightning Callbacks."""
|
"""Lightning Callbacks."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.initializers import LiteralCompInitializer
|
|
||||||
|
|
||||||
|
from ..core.components import Components
|
||||||
|
from ..core.initializers import LiteralCompInitializer
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from prototorch.models import GLVQ, GrowingNeuralGas
|
|
||||||
|
|
||||||
|
|
||||||
class PruneLoserPrototypes(pl.Callback):
|
class PruneLoserPrototypes(pl.Callback):
|
||||||
|
def __init__(self,
|
||||||
def __init__(
|
threshold=0.01,
|
||||||
self,
|
idle_epochs=10,
|
||||||
threshold=0.01,
|
prune_quota_per_epoch=-1,
|
||||||
idle_epochs=10,
|
frequency=1,
|
||||||
prune_quota_per_epoch=-1,
|
replace=False,
|
||||||
frequency=1,
|
prototypes_initializer=None,
|
||||||
replace=False,
|
verbose=False):
|
||||||
prototypes_initializer=None,
|
|
||||||
verbose=False,
|
|
||||||
):
|
|
||||||
self.threshold = threshold # minimum win ratio
|
self.threshold = threshold # minimum win ratio
|
||||||
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
||||||
self.prune_quota_per_epoch = prune_quota_per_epoch
|
self.prune_quota_per_epoch = prune_quota_per_epoch
|
||||||
@ -33,7 +27,7 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.prototypes_initializer = prototypes_initializer
|
self.prototypes_initializer = prototypes_initializer
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||||
return None
|
return None
|
||||||
if (trainer.current_epoch + 1) % self.frequency:
|
if (trainer.current_epoch + 1) % self.frequency:
|
||||||
@ -48,44 +42,41 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
prune_labels = prune_labels[:self.prune_quota_per_epoch]
|
prune_labels = prune_labels[:self.prune_quota_per_epoch]
|
||||||
|
|
||||||
if len(to_prune) > 0:
|
if len(to_prune) > 0:
|
||||||
logging.debug(f"\nPrototype win ratios: {ratios}")
|
if self.verbose:
|
||||||
logging.debug(f"Pruning prototypes at: {to_prune}")
|
print(f"\nPrototype win ratios: {ratios}")
|
||||||
logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
|
print(f"Pruning prototypes at: {to_prune}")
|
||||||
|
print(f"Corresponding labels are: {prune_labels.tolist()}")
|
||||||
cur_num_protos = pl_module.num_prototypes
|
cur_num_protos = pl_module.num_prototypes
|
||||||
pl_module.remove_prototypes(indices=to_prune)
|
pl_module.remove_prototypes(indices=to_prune)
|
||||||
|
|
||||||
if self.replace:
|
if self.replace:
|
||||||
labels, counts = torch.unique(prune_labels,
|
labels, counts = torch.unique(prune_labels,
|
||||||
sorted=True,
|
sorted=True,
|
||||||
return_counts=True)
|
return_counts=True)
|
||||||
distribution = dict(zip(labels.tolist(), counts.tolist()))
|
distribution = dict(zip(labels.tolist(), counts.tolist()))
|
||||||
|
if self.verbose:
|
||||||
logging.info(f"Re-adding pruned prototypes...")
|
print(f"Re-adding pruned prototypes...")
|
||||||
logging.debug(f"distribution={distribution}")
|
print(f"{distribution=}")
|
||||||
|
|
||||||
pl_module.add_prototypes(
|
pl_module.add_prototypes(
|
||||||
distribution=distribution,
|
distribution=distribution,
|
||||||
components_initializer=self.prototypes_initializer)
|
components_initializer=self.prototypes_initializer)
|
||||||
new_num_protos = pl_module.num_prototypes
|
new_num_protos = pl_module.num_prototypes
|
||||||
|
if self.verbose:
|
||||||
logging.info(f"`num_prototypes` changed from {cur_num_protos} "
|
print(f"`num_prototypes` changed from {cur_num_protos} "
|
||||||
f"to {new_num_protos}.")
|
f"to {new_num_protos}.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class PrototypeConvergence(pl.Callback):
|
class PrototypeConvergence(pl.Callback):
|
||||||
|
|
||||||
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
|
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
|
||||||
self.min_delta = min_delta
|
self.min_delta = min_delta
|
||||||
self.idle_epochs = idle_epochs # epochs to wait
|
self.idle_epochs = idle_epochs # epochs to wait
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||||
return None
|
return None
|
||||||
|
if self.verbose:
|
||||||
logging.info("Stopping...")
|
print("Stopping...")
|
||||||
# TODO
|
# TODO
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -98,21 +89,16 @@ class GNGCallback(pl.Callback):
|
|||||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reduction=0.1, freq=10):
|
def __init__(self, reduction=0.1, freq=10):
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
self.freq = freq
|
self.freq = freq
|
||||||
|
|
||||||
def on_train_epoch_end(
|
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
|
||||||
self,
|
|
||||||
trainer: pl.Trainer,
|
|
||||||
pl_module: "GrowingNeuralGas",
|
|
||||||
):
|
|
||||||
if (trainer.current_epoch + 1) % self.freq == 0:
|
if (trainer.current_epoch + 1) % self.freq == 0:
|
||||||
# Get information
|
# Get information
|
||||||
errors = pl_module.errors
|
errors = pl_module.errors
|
||||||
topology: ConnectionTopology = pl_module.topology_layer
|
topology: ConnectionTopology = pl_module.topology_layer
|
||||||
components = pl_module.proto_layer.components
|
components: Components = pl_module.proto_layer.components
|
||||||
|
|
||||||
# Insertion point
|
# Insertion point
|
||||||
worst = torch.argmax(errors)
|
worst = torch.argmax(errors)
|
||||||
@ -132,9 +118,8 @@ class GNGCallback(pl.Callback):
|
|||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.add_components(
|
pl_module.proto_layer.add_components(
|
||||||
1,
|
None,
|
||||||
initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
|
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
@ -149,4 +134,4 @@ class GNGCallback(pl.Callback):
|
|||||||
pl_module.errors[
|
pl_module.errors[
|
||||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||||
|
|
||||||
trainer.strategy.setup_optimizers(trainer)
|
trainer.accelerator_backend.setup_optimizers(trainer)
|
@ -1,21 +1,20 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core.competitions import CBCC
|
|
||||||
from prototorch.core.components import ReasoningComponents
|
|
||||||
from prototorch.core.initializers import RandomReasoningsInitializer
|
|
||||||
from prototorch.core.losses import MarginLoss
|
|
||||||
from prototorch.core.similarities import euclidean_similarity
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
|
||||||
|
|
||||||
|
from ..core.competitions import CBCC
|
||||||
|
from ..core.components import ReasoningComponents
|
||||||
|
from ..core.initializers import RandomReasoningsInitializer
|
||||||
|
from ..core.losses import MarginLoss
|
||||||
|
from ..core.similarities import euclidean_similarity
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import ImagePrototypesMixin
|
from .abstract import ImagePrototypesMixin
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
class CBC(SiameseGLVQ):
|
class CBC(SiameseGLVQ):
|
||||||
"""Classification-By-Components."""
|
"""Classification-By-Components."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||||
components_initializer = kwargs.get("components_initializer", None)
|
components_initializer = kwargs.get("components_initializer", None)
|
||||||
@ -44,7 +43,7 @@ class CBC(SiameseGLVQ):
|
|||||||
probs = self.competition_layer(detections, reasonings)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.num_classes
|
num_classes = self.num_classes
|
||||||
@ -52,23 +51,17 @@ class CBC(SiameseGLVQ):
|
|||||||
loss = self.loss(y_pred, y_true).mean()
|
loss = self.loss(y_pred, y_true).mean()
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
y_pred, train_loss = self.shared_step(batch, batch_idx)
|
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
preds = torch.argmax(y_pred, dim=1)
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
accuracy = torchmetrics.functional.accuracy(
|
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
||||||
preds.int(),
|
batch[1].int())
|
||||||
batch[1].int(),
|
self.log("train_acc",
|
||||||
"multiclass",
|
accuracy,
|
||||||
num_classes=self.num_classes,
|
on_step=False,
|
||||||
)
|
on_epoch=True,
|
||||||
self.log(
|
prog_bar=True,
|
||||||
"train_acc",
|
logger=True)
|
||||||
accuracy,
|
|
||||||
on_step=False,
|
|
||||||
on_epoch=True,
|
|
||||||
prog_bar=True,
|
|
||||||
logger=True,
|
|
||||||
)
|
|
||||||
return train_loss
|
return train_loss
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
123
prototorch/models/data.py
Normal file
123
prototorch/models/data.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
"""Prototorch Data Modules
|
||||||
|
|
||||||
|
This allows to store the used dataset inside a Lightning Module.
|
||||||
|
Mainly used for PytorchLightningCLI configurations.
|
||||||
|
"""
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from torch.utils.data import DataLoader, Dataset, random_split
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
|
|
||||||
|
# MNIST
|
||||||
|
class MNISTDataModule(pl.LightningDataModule):
|
||||||
|
def __init__(self, batch_size=32):
|
||||||
|
super().__init__()
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
# Download mnist dataset as side-effect, only called on the first cpu
|
||||||
|
def prepare_data(self):
|
||||||
|
MNIST("~/datasets", train=True, download=True)
|
||||||
|
MNIST("~/datasets", train=False, download=True)
|
||||||
|
|
||||||
|
# called for every GPU/machine (assigning state is OK)
|
||||||
|
def setup(self, stage=None):
|
||||||
|
# Transforms
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
# Split dataset
|
||||||
|
if stage in (None, "fit"):
|
||||||
|
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
||||||
|
self.mnist_train, self.mnist_val = random_split(
|
||||||
|
mnist_train,
|
||||||
|
[55000, 5000],
|
||||||
|
)
|
||||||
|
if stage == (None, "test"):
|
||||||
|
self.mnist_test = MNIST(
|
||||||
|
"~/datasets",
|
||||||
|
train=False,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
def train_dataloader(self):
|
||||||
|
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||||
|
return mnist_train
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
|
||||||
|
return mnist_val
|
||||||
|
|
||||||
|
def test_dataloader(self):
|
||||||
|
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
|
||||||
|
return mnist_test
|
||||||
|
|
||||||
|
|
||||||
|
# def train_on_mnist(batch_size=256) -> type:
|
||||||
|
# class DataClass(pl.LightningModule):
|
||||||
|
# datamodule = MNISTDataModule(batch_size=batch_size)
|
||||||
|
|
||||||
|
# def __init__(self, *args, **kwargs):
|
||||||
|
# prototype_initializer = kwargs.pop(
|
||||||
|
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||||
|
# super().__init__(*args,
|
||||||
|
# prototype_initializer=prototype_initializer,
|
||||||
|
# **kwargs)
|
||||||
|
|
||||||
|
# dc: Type[DataClass] = DataClass
|
||||||
|
# return dc
|
||||||
|
|
||||||
|
|
||||||
|
# ABSTRACT
|
||||||
|
class GeneralDataModule(pl.LightningDataModule):
|
||||||
|
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.train_dataset = dataset
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def train_dataloader(self) -> DataLoader:
|
||||||
|
return DataLoader(self.train_dataset, batch_size=self.batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
|
||||||
|
# class DataClass(pl.LightningModule):
|
||||||
|
# datamodule = GeneralDataModule(dataset, batch_size)
|
||||||
|
# datashape = dataset[0][0].shape
|
||||||
|
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
|
||||||
|
|
||||||
|
# def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
# prototype_initializer = kwargs.pop(
|
||||||
|
# "prototype_initializer",
|
||||||
|
# pt.components.Zeros(self.datashape),
|
||||||
|
# )
|
||||||
|
# super().__init__(*args,
|
||||||
|
# prototype_initializer=prototype_initializer,
|
||||||
|
# **kwargs)
|
||||||
|
|
||||||
|
# return DataClass
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# from prototorch.models import GLVQ
|
||||||
|
|
||||||
|
# demo_dataset = pt.datasets.Iris()
|
||||||
|
|
||||||
|
# TrainingClass: Type = train_on_dataset(demo_dataset)
|
||||||
|
|
||||||
|
# class DemoGLVQ(TrainingClass, GLVQ):
|
||||||
|
# """Model Definition."""
|
||||||
|
|
||||||
|
# # Hyperparameters
|
||||||
|
# hparams = dict(
|
||||||
|
# distribution={
|
||||||
|
# "num_classes": 3,
|
||||||
|
# "prototypes_per_class": 4
|
||||||
|
# },
|
||||||
|
# lr=0.01,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# initialized = DemoGLVQ(hparams)
|
||||||
|
# print(initialized)
|
@ -5,7 +5,8 @@ Modules not yet available in prototorch go here temporarily.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.similarities import gaussian
|
|
||||||
|
from ..core.similarities import gaussian
|
||||||
|
|
||||||
|
|
||||||
def rank_scaled_gaussian(distances, lambd):
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
@ -14,46 +15,7 @@ def rank_scaled_gaussian(distances, lambd):
|
|||||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
def orthogonalization(tensors):
|
|
||||||
"""Orthogonalization via polar decomposition """
|
|
||||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
|
||||||
u_shape = tuple(list(u.shape))
|
|
||||||
v_shape = tuple(list(v.shape))
|
|
||||||
|
|
||||||
# reshape to (num x N x M)
|
|
||||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
|
||||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
|
||||||
|
|
||||||
out = u @ v.permute([0, 2, 1])
|
|
||||||
|
|
||||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def ltangent_distance(x, y, omegas):
|
|
||||||
r"""Localized Tangent distance.
|
|
||||||
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
|
|
||||||
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
|
|
||||||
|
|
||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
|
||||||
:rtype: `torch.tensor`
|
|
||||||
"""
|
|
||||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
|
||||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
|
||||||
omegas, omegas.permute([0, 2, 1]))
|
|
||||||
projected_x = x @ p
|
|
||||||
projected_y = torch.diagonal(y @ p).T
|
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
|
||||||
batchwise_difference = expanded_y - projected_x
|
|
||||||
differences_squared = batchwise_difference**2
|
|
||||||
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
|
|
||||||
distances = distances.permute(1, 0)
|
|
||||||
return distances
|
|
||||||
|
|
||||||
|
|
||||||
class GaussianPrior(torch.nn.Module):
|
class GaussianPrior(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, variance):
|
def __init__(self, variance):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.variance = variance
|
self.variance = variance
|
||||||
@ -63,7 +25,6 @@ class GaussianPrior(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class RankScaledGaussianPrior(torch.nn.Module):
|
class RankScaledGaussianPrior(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, lambd):
|
def __init__(self, lambd):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lambd = lambd
|
self.lambd = lambd
|
||||||
@ -73,7 +34,6 @@ class RankScaledGaussianPrior(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ConnectionTopology(torch.nn.Module):
|
class ConnectionTopology(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, agelimit, num_prototypes):
|
def __init__(self, agelimit, num_prototypes):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.agelimit = agelimit
|
self.agelimit = agelimit
|
@ -1,29 +1,19 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.competitions import wtac
|
|
||||||
from prototorch.core.distances import (
|
|
||||||
lomega_distance,
|
|
||||||
omega_distance,
|
|
||||||
squared_euclidean_distance,
|
|
||||||
)
|
|
||||||
from prototorch.core.initializers import EyeLinearTransformInitializer
|
|
||||||
from prototorch.core.losses import (
|
|
||||||
GLVQLoss,
|
|
||||||
lvq1_loss,
|
|
||||||
lvq21_loss,
|
|
||||||
)
|
|
||||||
from prototorch.core.transforms import LinearTransform
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||||
|
from ..core.initializers import EyeTransformInitializer
|
||||||
|
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||||
|
from ..core.transforms import LinearTransform
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
from .extras import ltangent_distance, orthogonalization
|
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(SupervisedPrototypeModel):
|
class GLVQ(SupervisedPrototypeModel):
|
||||||
"""Generalized Learning Vector Quantization."""
|
"""Generalized Learning Vector Quantization."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
@ -34,21 +24,17 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
self.loss = GLVQLoss(
|
self.loss = GLVQLoss(
|
||||||
margin=self.hparams["margin"],
|
margin=self.hparams.margin,
|
||||||
transfer_fn=self.hparams["transfer_fn"],
|
transfer_fn=self.hparams.transfer_fn,
|
||||||
beta=self.hparams["transfer_beta"],
|
beta=self.hparams.transfer_beta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# def on_save_checkpoint(self, checkpoint):
|
|
||||||
# if "prototype_win_ratios" in checkpoint["state_dict"]:
|
|
||||||
# del checkpoint["state_dict"]["prototype_win_ratios"]
|
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"prototype_win_ratios",
|
"prototype_win_ratios",
|
||||||
torch.zeros(self.num_prototypes, device=self.device))
|
torch.zeros(self.num_prototypes, device=self.device))
|
||||||
|
|
||||||
def on_train_epoch_start(self):
|
def on_epoch_start(self):
|
||||||
self.initialize_prototype_win_ratios()
|
self.initialize_prototype_win_ratios()
|
||||||
|
|
||||||
def log_prototype_win_ratios(self, distances):
|
def log_prototype_win_ratios(self, distances):
|
||||||
@ -66,15 +52,15 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
prototype_wr,
|
prototype_wr,
|
||||||
])
|
])
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
_, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
loss = self.loss(out, y, plabels)
|
loss = self.loss(out, y, plabels)
|
||||||
return out, loss
|
return out, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx)
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
self.log_prototype_win_ratios(out)
|
self.log_prototype_win_ratios(out)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
@ -99,6 +85,10 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
test_loss += batch_loss.item()
|
test_loss += batch_loss.item()
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
||||||
@ -108,7 +98,6 @@ class SiameseGLVQ(GLVQ):
|
|||||||
transformation pipeline are only learned from the inputs.
|
transformation pipeline are only learned from the inputs.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
hparams,
|
hparams,
|
||||||
backbone=torch.nn.Identity(),
|
backbone=torch.nn.Identity(),
|
||||||
@ -119,17 +108,32 @@ class SiameseGLVQ(GLVQ):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.both_path_gradients = both_path_gradients
|
self.both_path_gradients = both_path_gradients
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
|
lr=self.hparams.proto_lr)
|
||||||
|
# Only add a backbone optimizer if backbone has trainable parameters
|
||||||
|
if (bb_params := list(self.backbone.parameters())):
|
||||||
|
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
|
||||||
|
optimizers = [proto_opt, bb_opt]
|
||||||
|
else:
|
||||||
|
optimizers = [proto_opt]
|
||||||
|
if self.lr_scheduler is not None:
|
||||||
|
schedulers = []
|
||||||
|
for optimizer in optimizers:
|
||||||
|
scheduler = self.lr_scheduler(optimizer,
|
||||||
|
**self.lr_scheduler_kwargs)
|
||||||
|
schedulers.append(scheduler)
|
||||||
|
return optimizers, schedulers
|
||||||
|
else:
|
||||||
|
return optimizers
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
|
||||||
|
|
||||||
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
self.backbone.requires_grad_(bb_grad)
|
self.backbone.requires_grad_(True)
|
||||||
|
|
||||||
distances = self.distance_layer(latent_x, latent_protos)
|
distances = self.distance_layer(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@ -159,7 +163,6 @@ class LVQMLN(SiameseGLVQ):
|
|||||||
rather in the embedding space.
|
rather in the embedding space.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
latent_protos, _ = self.proto_layer()
|
latent_protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
@ -175,22 +178,17 @@ class GRLVQ(SiameseGLVQ):
|
|||||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_relevances: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
relevances = torch.ones(self.hparams["input_dim"], device=self.device)
|
relevances = torch.ones(self.hparams.input_dim, device=self.device)
|
||||||
self.register_parameter("_relevances", Parameter(relevances))
|
self.register_parameter("_relevances", Parameter(relevances))
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
self.backbone = LambdaLayer(self._apply_relevances,
|
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
|
||||||
name="relevance scaling")
|
name="relevance scaling")
|
||||||
|
|
||||||
def _apply_relevances(self, x):
|
|
||||||
return x @ torch.diag(self._relevances)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relevance_profile(self):
|
def relevance_profile(self):
|
||||||
return self._relevances.detach().cpu()
|
return self._relevances.detach().cpu()
|
||||||
@ -205,16 +203,15 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
Implemented as a Siamese network with a linear transformation backbone.
|
Implemented as a Siamese network with a linear transformation backbone.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
EyeLinearTransformInitializer())
|
EyeTransformInitializer())
|
||||||
self.backbone = LinearTransform(
|
self.backbone = LinearTransform(
|
||||||
self.hparams["input_dim"],
|
self.hparams.input_dim,
|
||||||
self.hparams["latent_dim"],
|
self.hparams.output_dim,
|
||||||
initializer=omega_initializer,
|
initializer=omega_initializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -224,7 +221,7 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def lambda_matrix(self):
|
def lambda_matrix(self):
|
||||||
omega = self.backbone.weights # (input_dim, latent_dim)
|
omega = self.backbone.weight # (input_dim, latent_dim)
|
||||||
lam = omega @ omega.T
|
lam = omega @ omega.T
|
||||||
return lam.detach().cpu()
|
return lam.detach().cpu()
|
||||||
|
|
||||||
@ -236,31 +233,23 @@ class GMLVQ(GLVQ):
|
|||||||
function. This makes it easier to implement a localized variant.
|
function. This makes it easier to implement a localized variant.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters
|
|
||||||
_omega: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
EyeLinearTransformInitializer())
|
EyeTransformInitializer())
|
||||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||||
self.hparams["latent_dim"])
|
self.hparams.latent_dim)
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
self.backbone = LambdaLayer(lambda x: x @ self._omega,
|
||||||
|
name="omega matrix")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
return self._omega.detach().cpu()
|
return self._omega.detach().cpu()
|
||||||
|
|
||||||
@property
|
|
||||||
def lambda_matrix(self):
|
|
||||||
omega = self._omega.detach() # (input_dim, latent_dim)
|
|
||||||
lam = omega @ omega.T
|
|
||||||
return lam.detach().cpu()
|
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
distances = self.distance_layer(x, protos, self._omega)
|
distances = self.distance_layer(x, protos, self._omega)
|
||||||
@ -272,7 +261,6 @@ class GMLVQ(GLVQ):
|
|||||||
|
|
||||||
class LGMLVQ(GMLVQ):
|
class LGMLVQ(GMLVQ):
|
||||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
@ -280,59 +268,15 @@ class LGMLVQ(GMLVQ):
|
|||||||
# Re-register `_omega` to override the one from the super class.
|
# Re-register `_omega` to override the one from the super class.
|
||||||
omega = torch.randn(
|
omega = torch.randn(
|
||||||
self.num_prototypes,
|
self.num_prototypes,
|
||||||
self.hparams["input_dim"],
|
self.hparams.input_dim,
|
||||||
self.hparams["latent_dim"],
|
self.hparams.latent_dim,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(LGMLVQ):
|
|
||||||
"""Localized and Generalized Tangent Learning Vector Quantization."""
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
|
||||||
|
|
||||||
omega_initializer = kwargs.get("omega_initializer")
|
|
||||||
|
|
||||||
if omega_initializer is not None:
|
|
||||||
subspace = omega_initializer.generate(
|
|
||||||
self.hparams["input_dim"],
|
|
||||||
self.hparams["latent_dim"],
|
|
||||||
)
|
|
||||||
omega = torch.repeat_interleave(
|
|
||||||
subspace.unsqueeze(0),
|
|
||||||
self.num_prototypes,
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
omega = torch.rand(
|
|
||||||
self.num_prototypes,
|
|
||||||
self.hparams["input_dim"],
|
|
||||||
self.hparams["latent_dim"],
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-register `_omega` to override the one from the super class.
|
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
|
||||||
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
|
||||||
with torch.no_grad():
|
|
||||||
self._omega.copy_(orthogonalization(self._omega))
|
|
||||||
|
|
||||||
|
|
||||||
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
|
|
||||||
"""Generalized Tangent Learning Vector Quantization.
|
|
||||||
|
|
||||||
Implemented as a Siamese network with a linear transformation backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class GLVQ1(GLVQ):
|
class GLVQ1(GLVQ):
|
||||||
"""Generalized Learning Vector Quantization 1."""
|
"""Generalized Learning Vector Quantization 1."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.loss = LossLayer(lvq1_loss)
|
self.loss = LossLayer(lvq1_loss)
|
||||||
@ -341,7 +285,6 @@ class GLVQ1(GLVQ):
|
|||||||
|
|
||||||
class GLVQ21(GLVQ):
|
class GLVQ21(GLVQ):
|
||||||
"""Generalized Learning Vector Quantization 2.1."""
|
"""Generalized Learning Vector Quantization 2.1."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.loss = LossLayer(lvq21_loss)
|
self.loss = LossLayer(lvq21_loss)
|
||||||
@ -364,18 +307,3 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
|||||||
after updates.
|
after updates.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
|
|
||||||
"""GTLVQ for training on image data.
|
|
||||||
|
|
||||||
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
|
||||||
after updates.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
|
||||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
|
||||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
|
||||||
with torch.no_grad():
|
|
||||||
self._omega.copy_(orthogonalization(self._omega))
|
|
@ -2,22 +2,17 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from prototorch.core.competitions import KNNC
|
from ..core.competitions import KNNC
|
||||||
from prototorch.core.components import LabeledComponents
|
from ..core.components import LabeledComponents
|
||||||
from prototorch.core.initializers import (
|
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||||
LiteralCompInitializer,
|
from ..utils.utils import parse_data_arg
|
||||||
LiteralLabelsInitializer,
|
|
||||||
)
|
|
||||||
from prototorch.utils.utils import parse_data_arg
|
|
||||||
|
|
||||||
from .abstract import SupervisedPrototypeModel
|
from .abstract import SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
class KNN(SupervisedPrototypeModel):
|
class KNN(SupervisedPrototypeModel):
|
||||||
"""K-Nearest-Neighbors classification algorithm."""
|
"""K-Nearest-Neighbors classification algorithm."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Default hparams
|
# Default hparams
|
||||||
self.hparams.setdefault("k", 1)
|
self.hparams.setdefault("k", 1)
|
||||||
@ -29,15 +24,18 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=len(data) * [1],
|
distribution=[],
|
||||||
components_initializer=LiteralCompInitializer(data),
|
components_initializer=LiteralCompInitializer(data),
|
||||||
labels_initializer=LiteralLabelsInitializer(targets))
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
return 1 # skip training step
|
return 1 # skip training step
|
||||||
|
|
||||||
def on_train_batch_start(self, train_batch, batch_idx):
|
def on_train_batch_start(self,
|
||||||
|
train_batch,
|
||||||
|
batch_idx,
|
||||||
|
dataloader_idx=None):
|
||||||
warnings.warn("k-NN has no training, skipping!")
|
warnings.warn("k-NN has no training, skipping!")
|
||||||
return -1
|
return -1
|
||||||
|
|
@ -1,20 +1,18 @@
|
|||||||
"""LVQ models that are optimized using non-gradient methods."""
|
"""LVQ models that are optimized using non-gradient methods."""
|
||||||
|
|
||||||
import logging
|
from ..core.losses import _get_dp_dm
|
||||||
|
from ..nn.activations import get_activation
|
||||||
from prototorch.core.losses import _get_dp_dm
|
from ..nn.wrappers import LambdaLayer
|
||||||
from prototorch.nn.activations import get_activation
|
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
|
||||||
|
|
||||||
from .abstract import NonGradientMixin
|
from .abstract import NonGradientMixin
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
|
|
||||||
class LVQ1(NonGradientMixin, GLVQ):
|
class LVQ1(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 1."""
|
"""Learning Vector Quantization 1."""
|
||||||
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
|
protos = self.proto_layer.components
|
||||||
|
plabels = self.proto_layer.labels
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
|
||||||
protos, plables = self.proto_layer()
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
# TODO Vectorized implementation
|
# TODO Vectorized implementation
|
||||||
@ -32,8 +30,8 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||||
strict=False)
|
strict=False)
|
||||||
|
|
||||||
logging.debug(f"dis={dis}")
|
print(f"{dis=}")
|
||||||
logging.debug(f"y={y}")
|
print(f"{y=}")
|
||||||
# Logging
|
# Logging
|
||||||
self.log_acc(dis, y, tag="train_acc")
|
self.log_acc(dis, y, tag="train_acc")
|
||||||
|
|
||||||
@ -42,9 +40,9 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
|
|
||||||
class LVQ21(NonGradientMixin, GLVQ):
|
class LVQ21(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 2.1."""
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
def training_step(self, train_batch, batch_idx):
|
protos = self.proto_layer.components
|
||||||
protos, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -75,8 +73,8 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
# TODO Avoid computing distances over and over
|
# TODO Avoid computing distances over and over
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, hparams, verbose=True, **kwargs):
|
||||||
def __init__(self, hparams, **kwargs):
|
self.verbose = verbose
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
self.transfer_layer = LambdaLayer(
|
self.transfer_layer = LambdaLayer(
|
||||||
@ -100,8 +98,9 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
lower_bound = (gamma * f.log()).sum()
|
lower_bound = (gamma * f.log()).sum()
|
||||||
return lower_bound
|
return lower_bound
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos, plabels = self.proto_layer()
|
protos = self.proto_layer.components
|
||||||
|
plabels = self.proto_layer.labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -117,7 +116,8 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
_protos[i] = xk
|
_protos[i] = xk
|
||||||
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
|
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
|
||||||
if _lower_bound > lower_bound:
|
if _lower_bound > lower_bound:
|
||||||
logging.debug(f"Updating prototype {i} to data {k}...")
|
if self.verbose:
|
||||||
|
print(f"Updating prototype {i} to data {k}...")
|
||||||
self.proto_layer.load_state_dict({"_components": _protos},
|
self.proto_layer.load_state_dict({"_components": _protos},
|
||||||
strict=False)
|
strict=False)
|
||||||
break
|
break
|
58
prototorch/models/nam.py
Normal file
58
prototorch/models/nam.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""ProtoTorch Neural Additive Model."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
|
||||||
|
from .abstract import ProtoTorchBolt
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryNAM(ProtoTorchBolt):
|
||||||
|
"""Neural Additive Model for binary classification.
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/2004.13912
|
||||||
|
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
|
# Default hparams
|
||||||
|
self.hparams.setdefault("threshold", 0.5)
|
||||||
|
|
||||||
|
self.extractors = extractors
|
||||||
|
self.linear = torch.nn.Linear(in_features=len(extractors),
|
||||||
|
out_features=1,
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
def extract(self, x):
|
||||||
|
"""Apply the local extractors batch-wise on features."""
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
for j in range(x.shape[1]):
|
||||||
|
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.extract(x)
|
||||||
|
x = self.linear(x)
|
||||||
|
return torch.sigmoid(x)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
x, y = batch
|
||||||
|
preds = self(x).squeeze()
|
||||||
|
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
|
||||||
|
self.log("train_loss", train_loss)
|
||||||
|
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
|
||||||
|
self.log("train_acc",
|
||||||
|
accuracy,
|
||||||
|
on_step=False,
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True)
|
||||||
|
return train_loss
|
||||||
|
|
||||||
|
def predict(self, x):
|
||||||
|
out = self(x)
|
||||||
|
pred = torch.zeros_like(out, device=self.device)
|
||||||
|
pred[out > self.hparams.threshold] = 1
|
||||||
|
return pred
|
@ -1,30 +1,25 @@
|
|||||||
"""Probabilistic GLVQ methods"""
|
"""Probabilistic GLVQ methods"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.losses import nllr_loss, rslvq_loss
|
|
||||||
from prototorch.core.pooling import (
|
|
||||||
stratified_min_pooling,
|
|
||||||
stratified_sum_pooling,
|
|
||||||
)
|
|
||||||
from prototorch.nn.wrappers import LossLayer
|
|
||||||
|
|
||||||
|
from ..core.losses import nllr_loss, rslvq_loss
|
||||||
|
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .extras import GaussianPrior, RankScaledGaussianPrior
|
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||||
from .glvq import GLVQ, SiameseGMLVQ
|
from .glvq import GLVQ, SiameseGMLVQ
|
||||||
|
|
||||||
|
|
||||||
class CELVQ(GLVQ):
|
class CELVQ(GLVQ):
|
||||||
"""Cross-Entropy Learning Vector Quantization."""
|
"""Cross-Entropy Learning Vector Quantization."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
self.loss = torch.nn.CrossEntropyLoss()
|
self.loss = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
_, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
||||||
probs = -1.0 * winning
|
probs = -1.0 * winning
|
||||||
batch_loss = self.loss(probs, y.long())
|
batch_loss = self.loss(probs, y.long())
|
||||||
@ -33,28 +28,20 @@ class CELVQ(GLVQ):
|
|||||||
|
|
||||||
|
|
||||||
class ProbabilisticLVQ(GLVQ):
|
class ProbabilisticLVQ(GLVQ):
|
||||||
|
|
||||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
|
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||||
self.rejection_confidence = rejection_confidence
|
self.rejection_confidence = rejection_confidence
|
||||||
self._conditional_distribution = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self.compute_distances(x)
|
distances = self.compute_distances(x)
|
||||||
|
|
||||||
conditional = self.conditional_distribution(distances)
|
conditional = self.conditional_distribution(distances)
|
||||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
posterior = conditional * prior
|
posterior = conditional * prior
|
||||||
|
|
||||||
plabels = self.proto_layer._labels
|
plabels = self.proto_layer._labels
|
||||||
if isinstance(plabels, torch.LongTensor) or isinstance(
|
y_pred = stratified_sum_pooling(posterior, plabels)
|
||||||
plabels, torch.cuda.LongTensor): # type: ignore
|
|
||||||
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
|
|
||||||
else:
|
|
||||||
raise ValueError("Labels must be LongTensor.")
|
|
||||||
|
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
||||||
@ -63,46 +50,27 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
prediction[confidence < self.rejection_confidence] = -1
|
prediction[confidence < self.rejection_confidence] = -1
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.forward(x)
|
out = self.forward(x)
|
||||||
_, plabels = self.proto_layer()
|
plabels = self.proto_layer.labels
|
||||||
batch_loss = self.loss(out, y, plabels)
|
batch_loss = self.loss(out, y, plabels)
|
||||||
loss = batch_loss.sum()
|
train_loss = batch_loss.sum()
|
||||||
return loss
|
self.log("train_loss", train_loss)
|
||||||
|
return train_loss
|
||||||
def conditional_distribution(self, distances):
|
|
||||||
"""Conditional distribution of distances."""
|
|
||||||
if self._conditional_distribution is None:
|
|
||||||
raise ValueError("Conditional distribution is not set.")
|
|
||||||
return self._conditional_distribution(distances)
|
|
||||||
|
|
||||||
|
|
||||||
class SLVQ(ProbabilisticLVQ):
|
class SLVQ(ProbabilisticLVQ):
|
||||||
"""Soft Learning Vector Quantization."""
|
"""Soft Learning Vector Quantization."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Default hparams
|
|
||||||
self.hparams.setdefault("variance", 1.0)
|
|
||||||
variance = self.hparams.get("variance")
|
|
||||||
|
|
||||||
self._conditional_distribution = GaussianPrior(variance)
|
|
||||||
self.loss = LossLayer(nllr_loss)
|
self.loss = LossLayer(nllr_loss)
|
||||||
|
|
||||||
|
|
||||||
class RSLVQ(ProbabilisticLVQ):
|
class RSLVQ(ProbabilisticLVQ):
|
||||||
"""Robust Soft Learning Vector Quantization."""
|
"""Robust Soft Learning Vector Quantization."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Default hparams
|
|
||||||
self.hparams.setdefault("variance", 1.0)
|
|
||||||
variance = self.hparams.get("variance")
|
|
||||||
|
|
||||||
self._conditional_distribution = GaussianPrior(variance)
|
|
||||||
self.loss = LossLayer(rslvq_loss)
|
self.loss = LossLayer(rslvq_loss)
|
||||||
|
|
||||||
|
|
||||||
@ -111,19 +79,14 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
|||||||
|
|
||||||
TODO: Use Backbone LVQ instead
|
TODO: Use Backbone LVQ instead
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conditional_distribution = RankScaledGaussianPrior(
|
||||||
# Default hparams
|
self.hparams.lambd)
|
||||||
self.hparams.setdefault("lambda", 1.0)
|
|
||||||
lam = self.hparams.get("lambda", 1.0)
|
|
||||||
|
|
||||||
self.conditional_distribution = RankScaledGaussianPrior(lam)
|
|
||||||
self.loss = torch.nn.KLDivLoss()
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
# def training_step(self, batch, batch_idx):
|
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
# x, y = batch
|
# x, y = batch
|
||||||
# y_pred = self(x)
|
# y_pred = self(x)
|
||||||
# batch_loss = self.loss(y_pred, y)
|
# batch_loss = self.loss(y_pred, y)
|
@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.core.competitions import wtac
|
|
||||||
from prototorch.core.distances import squared_euclidean_distance
|
|
||||||
from prototorch.core.losses import NeuralGasEnergy
|
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import squared_euclidean_distance
|
||||||
|
from ..core.losses import NeuralGasEnergy
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||||
from .callbacks import GNGCallback
|
from .callbacks import GNGCallback
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
@ -17,8 +18,6 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
TODO Allow non-2D grids
|
TODO Allow non-2D grids
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_grid: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
h, w = hparams.get("shape")
|
h, w = hparams.get("shape")
|
||||||
# Ignore `num_prototypes`
|
# Ignore `num_prototypes`
|
||||||
@ -35,7 +34,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
x, y = torch.arange(h), torch.arange(w)
|
x, y = torch.arange(h), torch.arange(w)
|
||||||
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
|
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
|
||||||
self.register_buffer("_grid", grid)
|
self.register_buffer("_grid", grid)
|
||||||
self._sigma = self.hparams.sigma
|
self._sigma = self.hparams.sigma
|
||||||
self._lr = self.hparams.lr
|
self._lr = self.hparams.lr
|
||||||
@ -54,16 +53,14 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
grid = self._grid.view(-1, 2)
|
grid = self._grid.view(-1, 2)
|
||||||
gd = squared_euclidean_distance(wp, grid)
|
gd = squared_euclidean_distance(wp, grid)
|
||||||
nh = torch.exp(-gd / self._sigma**2)
|
nh = torch.exp(-gd / self._sigma**2)
|
||||||
protos = self.proto_layer()
|
protos = self.proto_layer.components
|
||||||
diff = x.unsqueeze(dim=1) - protos
|
diff = x.unsqueeze(dim=1) - protos
|
||||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||||
updated_protos = protos + delta.sum(dim=0)
|
updated_protos = protos + delta.sum(dim=0)
|
||||||
self.proto_layer.load_state_dict(
|
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||||
{"_components": updated_protos},
|
strict=False)
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_training_epoch_end(self, training_step_outputs):
|
def training_epoch_end(self, training_step_outputs):
|
||||||
self._sigma = self.hparams.sigma * np.exp(
|
self._sigma = self.hparams.sigma * np.exp(
|
||||||
-self.current_epoch / self.trainer.max_epochs)
|
-self.current_epoch / self.trainer.max_epochs)
|
||||||
|
|
||||||
@ -72,7 +69,6 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
|
|
||||||
|
|
||||||
class HeskesSOM(UnsupervisedPrototypeModel):
|
class HeskesSOM(UnsupervisedPrototypeModel):
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
@ -82,7 +78,6 @@ class HeskesSOM(UnsupervisedPrototypeModel):
|
|||||||
|
|
||||||
|
|
||||||
class NeuralGas(UnsupervisedPrototypeModel):
|
class NeuralGas(UnsupervisedPrototypeModel):
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
@ -90,13 +85,13 @@ class NeuralGas(UnsupervisedPrototypeModel):
|
|||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
# Default hparams
|
# Default hparams
|
||||||
self.hparams.setdefault("age_limit", 10)
|
self.hparams.setdefault("agelimit", 10)
|
||||||
self.hparams.setdefault("lm", 1)
|
self.hparams.setdefault("lm", 1)
|
||||||
|
|
||||||
self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
|
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
||||||
self.topology_layer = ConnectionTopology(
|
self.topology_layer = ConnectionTopology(
|
||||||
agelimit=self.hparams["age_limit"],
|
agelimit=self.hparams.agelimit,
|
||||||
num_prototypes=self.hparams["num_prototypes"],
|
num_prototypes=self.hparams.num_prototypes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx):
|
||||||
@ -109,10 +104,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
|
|||||||
self.log("loss", loss)
|
self.log("loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
# def training_epoch_end(self, training_step_outputs):
|
||||||
|
# print(f"{self.trainer.lr_schedulers}")
|
||||||
|
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
|
||||||
|
|
||||||
|
|
||||||
class GrowingNeuralGas(NeuralGas):
|
class GrowingNeuralGas(NeuralGas):
|
||||||
errors: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
@ -121,10 +118,7 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
self.hparams.setdefault("insert_reduction", 0.1)
|
self.hparams.setdefault("insert_reduction", 0.1)
|
||||||
self.hparams.setdefault("insert_freq", 10)
|
self.hparams.setdefault("insert_freq", 10)
|
||||||
|
|
||||||
errors = torch.zeros(
|
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
|
||||||
self.hparams["num_prototypes"],
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
self.register_buffer("errors", errors)
|
self.register_buffer("errors", errors)
|
||||||
|
|
||||||
def training_step(self, train_batch, _batch_idx):
|
def training_step(self, train_batch, _batch_idx):
|
||||||
@ -139,7 +133,7 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
dp = d * mask
|
dp = d * mask
|
||||||
|
|
||||||
self.errors += torch.sum(dp * dp)
|
self.errors += torch.sum(dp * dp)
|
||||||
self.errors *= self.hparams["step_reduction"]
|
self.errors *= self.hparams.step_reduction
|
||||||
|
|
||||||
self.topology_layer(d)
|
self.topology_layer(d)
|
||||||
self.log("loss", loss)
|
self.log("loss", loss)
|
||||||
@ -147,8 +141,6 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
|
|
||||||
def configure_callbacks(self):
|
def configure_callbacks(self):
|
||||||
return [
|
return [
|
||||||
GNGCallback(
|
GNGCallback(reduction=self.hparams.insert_reduction,
|
||||||
reduction=self.hparams["insert_reduction"],
|
freq=self.hparams.insert_freq)
|
||||||
freq=self.hparams["insert_freq"],
|
|
||||||
)
|
|
||||||
]
|
]
|
@ -1,28 +1,20 @@
|
|||||||
"""Visualization Callbacks."""
|
"""Visualization Callbacks."""
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Sized
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.utils.colors import get_colors, get_legend_handles
|
|
||||||
from prototorch.utils.utils import mesh2d
|
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from ..utils.utils import mesh2d
|
||||||
|
|
||||||
|
|
||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data=None,
|
data,
|
||||||
title="Prototype Visualization",
|
title="Prototype Visualization",
|
||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
xlabel="Data dimension 1",
|
|
||||||
ylabel="Data dimension 2",
|
|
||||||
legend_labels=None,
|
|
||||||
border=0.1,
|
border=0.1,
|
||||||
resolution=100,
|
resolution=100,
|
||||||
flatten_data=True,
|
flatten_data=True,
|
||||||
@ -35,36 +27,24 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
block=False):
|
block=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if data:
|
if isinstance(data, Dataset):
|
||||||
if isinstance(data, Dataset):
|
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||||
if isinstance(data, Sized):
|
elif isinstance(data, torch.utils.data.DataLoader):
|
||||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
x = torch.tensor([])
|
||||||
else:
|
y = torch.tensor([])
|
||||||
# TODO: Add support for non-sized datasets
|
for x_b, y_b in data:
|
||||||
raise NotImplementedError(
|
x = torch.cat([x, x_b])
|
||||||
"Data must be a dataset with a __len__ method.")
|
y = torch.cat([y, y_b])
|
||||||
elif isinstance(data, DataLoader):
|
|
||||||
x = torch.tensor([])
|
|
||||||
y = torch.tensor([])
|
|
||||||
for x_b, y_b in data:
|
|
||||||
x = torch.cat([x, x_b])
|
|
||||||
y = torch.cat([y, y_b])
|
|
||||||
else:
|
|
||||||
x, y = data
|
|
||||||
|
|
||||||
if flatten_data:
|
|
||||||
x = x.reshape(len(x), -1)
|
|
||||||
|
|
||||||
self.x_train = x
|
|
||||||
self.y_train = y
|
|
||||||
else:
|
else:
|
||||||
self.x_train = None
|
x, y = data
|
||||||
self.y_train = None
|
|
||||||
|
if flatten_data:
|
||||||
|
x = x.reshape(len(x), -1)
|
||||||
|
|
||||||
|
self.x_train = x
|
||||||
|
self.y_train = y
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
self.xlabel = xlabel
|
|
||||||
self.ylabel = ylabel
|
|
||||||
self.legend_labels = legend_labels
|
|
||||||
self.fig = plt.figure(self.title)
|
self.fig = plt.figure(self.title)
|
||||||
self.cmap = cmap
|
self.cmap = cmap
|
||||||
self.border = border
|
self.border = border
|
||||||
@ -83,12 +63,14 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setup_ax(self):
|
def setup_ax(self, xlabel=None, ylabel=None):
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
ax.cla()
|
ax.cla()
|
||||||
ax.set_title(self.title)
|
ax.set_title(self.title)
|
||||||
ax.set_xlabel(self.xlabel)
|
if xlabel:
|
||||||
ax.set_ylabel(self.ylabel)
|
ax.set_xlabel("Data dimension 1")
|
||||||
|
if ylabel:
|
||||||
|
ax.set_ylabel("Data dimension 2")
|
||||||
if self.axis_off:
|
if self.axis_off:
|
||||||
ax.axis("off")
|
ax.axis("off")
|
||||||
return ax
|
return ax
|
||||||
@ -131,47 +113,60 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
else:
|
else:
|
||||||
plt.show(block=self.block)
|
plt.show(block=self.block)
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
self.visualize(pl_module)
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module):
|
def on_train_end(self, trainer, pl_module):
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
raise NotImplementedError
|
class Vis2D(Vis2DAbstract):
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
|
x_train, y_train = self.x_train, self.y_train
|
||||||
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
|
ylabel="Data dimension 2")
|
||||||
|
self.plot_data(ax, x_train, y_train)
|
||||||
|
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
|
||||||
|
mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
|
||||||
|
y_pred = pl_module.predict(mesh_input)
|
||||||
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
|
ylabel="Data dimension 2")
|
||||||
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
if x_train is not None:
|
x = np.vstack((x_train, protos))
|
||||||
self.plot_data(ax, x_train, y_train)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
|
||||||
self.border, self.resolution)
|
|
||||||
else:
|
|
||||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
class VisSiameseGLVQ2D(Vis2DAbstract):
|
class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||||
|
|
||||||
def __init__(self, *args, map_protos=True, **kwargs):
|
def __init__(self, *args, map_protos=True, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.map_protos = map_protos
|
self.map_protos = map_protos
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
@ -198,42 +193,18 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
class VisGMLVQ2D(Vis2DAbstract):
|
|
||||||
|
|
||||||
def __init__(self, *args, ev_proj=True, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.ev_proj = ev_proj
|
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
protos = pl_module.prototypes
|
|
||||||
plabels = pl_module.prototype_labels
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
|
||||||
device = pl_module.device
|
|
||||||
omega = pl_module._omega.detach()
|
|
||||||
lam = omega @ omega.T
|
|
||||||
u, _, _ = torch.pca_lowrank(lam, q=2)
|
|
||||||
with torch.no_grad():
|
|
||||||
x_train = torch.Tensor(x_train).to(device)
|
|
||||||
x_train = x_train @ u
|
|
||||||
x_train = x_train.cpu().detach()
|
|
||||||
if self.show_protos:
|
|
||||||
with torch.no_grad():
|
|
||||||
protos = torch.Tensor(protos).to(device)
|
|
||||||
protos = protos @ u
|
|
||||||
protos = protos.cpu().detach()
|
|
||||||
ax = self.setup_ax()
|
|
||||||
self.plot_data(ax, x_train, y_train)
|
|
||||||
if self.show_protos:
|
|
||||||
self.plot_protos(ax, protos, plabels)
|
|
||||||
|
|
||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.components
|
protos = pl_module.components
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
|
ylabel="Data dimension 2")
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
@ -245,15 +216,20 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
|
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
class VisNG2D(Vis2DAbstract):
|
class VisNG2D(Vis2DAbstract):
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||||
|
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
|
ylabel="Data dimension 2")
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
|
|
||||||
@ -267,27 +243,10 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
"k-",
|
"k-",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
class VisSpectralProtos(Vis2DAbstract):
|
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
|
||||||
protos = pl_module.prototypes
|
|
||||||
plabels = pl_module.prototype_labels
|
|
||||||
ax = self.setup_ax()
|
|
||||||
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
|
|
||||||
for p, pl in zip(protos, plabels):
|
|
||||||
ax.plot(p, c=colors[int(pl)])
|
|
||||||
if self.legend_labels:
|
|
||||||
handles = get_legend_handles(
|
|
||||||
colors,
|
|
||||||
self.legend_labels,
|
|
||||||
marker="lines",
|
|
||||||
)
|
|
||||||
ax.legend(handles=handles)
|
|
||||||
|
|
||||||
|
|
||||||
class VisImgComp(Vis2DAbstract):
|
class VisImgComp(Vis2DAbstract):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
*args,
|
*args,
|
||||||
random_data=0,
|
random_data=0,
|
||||||
@ -303,45 +262,32 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
self.add_embedding = add_embedding
|
self.add_embedding = add_embedding
|
||||||
self.embedding_data = embedding_data
|
self.embedding_data = embedding_data
|
||||||
|
|
||||||
def on_train_start(self, _, pl_module):
|
def on_train_start(self, trainer, pl_module):
|
||||||
if isinstance(pl_module.logger, TensorBoardLogger):
|
tb = pl_module.logger.experiment
|
||||||
tb = pl_module.logger.experiment
|
if self.add_embedding:
|
||||||
|
ind = np.random.choice(len(self.x_train),
|
||||||
|
size=self.embedding_data,
|
||||||
|
replace=False)
|
||||||
|
data = self.x_train[ind]
|
||||||
|
# print(f"{data.shape=}")
|
||||||
|
# print(f"{self.y_train[ind].shape=}")
|
||||||
|
tb.add_embedding(data.view(len(ind), -1),
|
||||||
|
label_img=data,
|
||||||
|
global_step=None,
|
||||||
|
tag="Data Embedding",
|
||||||
|
metadata=self.y_train[ind],
|
||||||
|
metadata_header=None)
|
||||||
|
|
||||||
# Add embedding
|
if self.random_data:
|
||||||
if self.add_embedding:
|
ind = np.random.choice(len(self.x_train),
|
||||||
if self.x_train is not None and self.y_train is not None:
|
size=self.random_data,
|
||||||
ind = np.random.choice(len(self.x_train),
|
replace=False)
|
||||||
size=self.embedding_data,
|
data = self.x_train[ind]
|
||||||
replace=False)
|
grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
|
||||||
data = self.x_train[ind]
|
tb.add_image(tag="Data",
|
||||||
tb.add_embedding(data.view(len(ind), -1),
|
img_tensor=grid,
|
||||||
label_img=data,
|
global_step=None,
|
||||||
global_step=None,
|
dataformats=self.dataformats)
|
||||||
tag="Data Embedding",
|
|
||||||
metadata=self.y_train[ind],
|
|
||||||
metadata_header=None)
|
|
||||||
else:
|
|
||||||
raise ValueError("No data for add embedding flag")
|
|
||||||
|
|
||||||
# Random Data
|
|
||||||
if self.random_data:
|
|
||||||
if self.x_train is not None:
|
|
||||||
ind = np.random.choice(len(self.x_train),
|
|
||||||
size=self.random_data,
|
|
||||||
replace=False)
|
|
||||||
data = self.x_train[ind]
|
|
||||||
grid = torchvision.utils.make_grid(data,
|
|
||||||
nrow=self.num_columns)
|
|
||||||
tb.add_image(tag="Data",
|
|
||||||
img_tensor=grid,
|
|
||||||
global_step=None,
|
|
||||||
dataformats=self.dataformats)
|
|
||||||
else:
|
|
||||||
raise ValueError("No data for random data flag")
|
|
||||||
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
|
|
||||||
|
|
||||||
def add_to_tensorboard(self, trainer, pl_module):
|
def add_to_tensorboard(self, trainer, pl_module):
|
||||||
tb = pl_module.logger.experiment
|
tb = pl_module.logger.experiment
|
||||||
@ -355,9 +301,14 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
dataformats=self.dataformats,
|
dataformats=self.dataformats,
|
||||||
)
|
)
|
||||||
|
|
||||||
def visualize(self, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
if self.show:
|
if self.show:
|
||||||
components = pl_module.components
|
components = pl_module.components
|
||||||
grid = torchvision.utils.make_grid(components,
|
grid = torchvision.utils.make_grid(components,
|
||||||
nrow=self.num_columns)
|
nrow=self.num_columns)
|
||||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||||
|
|
||||||
|
self.log_and_display(trainer, pl_module)
|
@ -1,90 +0,0 @@
|
|||||||
|
|
||||||
[project]
|
|
||||||
name = "prototorch-models"
|
|
||||||
version = "0.7.1"
|
|
||||||
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
|
|
||||||
authors = [
|
|
||||||
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
|
|
||||||
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
|
|
||||||
]
|
|
||||||
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
|
|
||||||
requires-python = ">=3.8"
|
|
||||||
readme = "README.md"
|
|
||||||
license = { text = "MIT" }
|
|
||||||
classifiers = [
|
|
||||||
"Development Status :: 2 - Pre-Alpha",
|
|
||||||
"Environment :: Plugins",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"Intended Audience :: Education",
|
|
||||||
"Intended Audience :: Science/Research",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Natural Language :: English",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Software Development :: Libraries",
|
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.urls]
|
|
||||||
Homepage = "https://github.com/si-cim/prototorch_models"
|
|
||||||
Downloads = "https://github.com/si-cim/prototorch_models.git"
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
|
|
||||||
examples = ["matplotlib", "scikit-learn"]
|
|
||||||
ci = ["pytest", "pre-commit"]
|
|
||||||
docs = [
|
|
||||||
"recommonmark",
|
|
||||||
"nbsphinx",
|
|
||||||
"sphinx",
|
|
||||||
"sphinx_rtd_theme",
|
|
||||||
"sphinxcontrib-bibtex",
|
|
||||||
"sphinxcontrib-katex",
|
|
||||||
"ipykernel",
|
|
||||||
]
|
|
||||||
all = [
|
|
||||||
"bumpversion",
|
|
||||||
"pre-commit",
|
|
||||||
"yapf",
|
|
||||||
"toml",
|
|
||||||
"pytest",
|
|
||||||
"matplotlib",
|
|
||||||
"scikit-learn",
|
|
||||||
"recommonmark",
|
|
||||||
"nbsphinx",
|
|
||||||
"sphinx",
|
|
||||||
"sphinx_rtd_theme",
|
|
||||||
"sphinxcontrib-bibtex",
|
|
||||||
"sphinxcontrib-katex",
|
|
||||||
"ipykernel",
|
|
||||||
]
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["setuptools>=61", "wheel"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[tool.yapf]
|
|
||||||
based_on_style = "pep8"
|
|
||||||
spaces_before_comment = 2
|
|
||||||
split_before_logical_operator = true
|
|
||||||
|
|
||||||
[tool.pylint]
|
|
||||||
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
|
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
profile = "hug"
|
|
||||||
src_paths = ["isort", "test"]
|
|
||||||
multi_line_output = 3
|
|
||||||
include_trailing_comma = true
|
|
||||||
force_grid_wrap = 3
|
|
||||||
use_parentheses = true
|
|
||||||
line_length = 79
|
|
||||||
|
|
||||||
[tool.mypy]
|
|
||||||
explicit_package_bases = true
|
|
||||||
namespace_packages = true
|
|
8
setup.cfg
Normal file
8
setup.cfg
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[isort]
|
||||||
|
profile = hug
|
||||||
|
src_paths = isort, test
|
||||||
|
|
||||||
|
[yapf]
|
||||||
|
based_on_style = pep8
|
||||||
|
spaces_before_comment = 2
|
||||||
|
split_before_logical_operator = true
|
93
setup.py
Normal file
93
setup.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
|
||||||
|
######
|
||||||
|
# # ##### #### ##### #### ##### #### ##### #### # #
|
||||||
|
# # # # # # # # # # # # # # # # # #
|
||||||
|
###### # # # # # # # # # # # # # ######
|
||||||
|
# ##### # # # # # # # # ##### # # #
|
||||||
|
# # # # # # # # # # # # # # # # #
|
||||||
|
# # # #### # #### # #### # # #### # #Plugin
|
||||||
|
|
||||||
|
ProtoTorch models Plugin Package
|
||||||
|
"""
|
||||||
|
from pkg_resources import safe_name
|
||||||
|
from setuptools import find_namespace_packages, setup
|
||||||
|
|
||||||
|
PLUGIN_NAME = "models"
|
||||||
|
|
||||||
|
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||||
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||||
|
|
||||||
|
with open("README.md", "r") as fh:
|
||||||
|
long_description = fh.read()
|
||||||
|
|
||||||
|
INSTALL_REQUIRES = [
|
||||||
|
"prototorch>=0.6.0",
|
||||||
|
"pytorch_lightning>=1.3.5",
|
||||||
|
"torchmetrics",
|
||||||
|
]
|
||||||
|
CLI = [
|
||||||
|
"jsonargparse",
|
||||||
|
]
|
||||||
|
DEV = [
|
||||||
|
"bumpversion",
|
||||||
|
"pre-commit",
|
||||||
|
]
|
||||||
|
DOCS = [
|
||||||
|
"recommonmark",
|
||||||
|
"sphinx",
|
||||||
|
"nbsphinx",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
"sphinxcontrib-katex",
|
||||||
|
"sphinxcontrib-bibtex",
|
||||||
|
]
|
||||||
|
EXAMPLES = [
|
||||||
|
"matplotlib",
|
||||||
|
"scikit-learn",
|
||||||
|
]
|
||||||
|
TESTS = [
|
||||||
|
"codecov",
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||||
|
version="0.2.0",
|
||||||
|
description="Pre-packaged prototype-based "
|
||||||
|
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
author="Alexander Engelsberger",
|
||||||
|
author_email="engelsbe@hs-mittweida.de",
|
||||||
|
url=PROJECT_URL,
|
||||||
|
download_url=DOWNLOAD_URL,
|
||||||
|
license="MIT",
|
||||||
|
python_requires=">=3.9",
|
||||||
|
install_requires=INSTALL_REQUIRES,
|
||||||
|
extras_require={
|
||||||
|
"dev": DEV,
|
||||||
|
"examples": EXAMPLES,
|
||||||
|
"tests": TESTS,
|
||||||
|
"all": ALL,
|
||||||
|
},
|
||||||
|
classifiers=[
|
||||||
|
"Development Status :: 2 - Pre-Alpha",
|
||||||
|
"Environment :: Plugins",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"Intended Audience :: Education",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: Software Development :: Libraries",
|
||||||
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
|
||||||
|
},
|
||||||
|
packages=find_namespace_packages(include=["prototorch.*"]),
|
||||||
|
zip_safe=False,
|
||||||
|
)
|
14
tests/test_.py
Normal file
14
tests/test_.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""prototorch.models test suite."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestDummy(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_dummy(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
@ -1,27 +1,11 @@
|
|||||||
#! /bin/bash
|
#! /bin/bash
|
||||||
|
|
||||||
|
|
||||||
# Read Flags
|
|
||||||
gpu=0
|
|
||||||
while [ -n "$1" ]; do
|
|
||||||
case "$1" in
|
|
||||||
--gpu) gpu=1;;
|
|
||||||
-g) gpu=1;;
|
|
||||||
*) path=$1;;
|
|
||||||
esac
|
|
||||||
shift
|
|
||||||
done
|
|
||||||
|
|
||||||
python --version
|
|
||||||
echo "Using GPU: " $gpu
|
|
||||||
|
|
||||||
# Loop
|
|
||||||
failed=0
|
failed=0
|
||||||
|
|
||||||
for example in $(find $path -maxdepth 1 -name "*.py")
|
for example in $(find $1 -maxdepth 1 -name "*.py")
|
||||||
do
|
do
|
||||||
echo -n "$x" $example '... '
|
echo -n "$x" $example '... '
|
||||||
export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
|
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
|
||||||
if [[ $? -ne 0 ]]; then
|
if [[ $? -ne 0 ]]; then
|
||||||
echo "FAILED!!"
|
echo "FAILED!!"
|
||||||
cat run_log.txt
|
cat run_log.txt
|
||||||
|
@ -1,193 +0,0 @@
|
|||||||
"""prototorch.models test suite."""
|
|
||||||
|
|
||||||
import prototorch.models
|
|
||||||
|
|
||||||
|
|
||||||
def test_glvq_model_build():
|
|
||||||
model = prototorch.models.GLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_glvq1_model_build():
|
|
||||||
model = prototorch.models.GLVQ1(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_glvq21_model_build():
|
|
||||||
model = prototorch.models.GLVQ1(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gmlvq_model_build():
|
|
||||||
model = prototorch.models.GMLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 2,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_grlvq_model_build():
|
|
||||||
model = prototorch.models.GRLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gtlvq_model_build():
|
|
||||||
model = prototorch.models.GTLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 4,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_lgmlvq_model_build():
|
|
||||||
model = prototorch.models.LGMLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 4,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_glvq_model_build():
|
|
||||||
model = prototorch.models.ImageGLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_gmlvq_model_build():
|
|
||||||
model = prototorch.models.ImageGMLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 16,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_gtlvq_model_build():
|
|
||||||
model = prototorch.models.ImageGMLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 16,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_glvq_model_build():
|
|
||||||
model = prototorch.models.SiameseGLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gmlvq_model_build():
|
|
||||||
model = prototorch.models.SiameseGMLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 4,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gtlvq_model_build():
|
|
||||||
model = prototorch.models.SiameseGTLVQ(
|
|
||||||
{
|
|
||||||
"distribution": (3, 2),
|
|
||||||
"input_dim": 4,
|
|
||||||
"latent_dim": 2,
|
|
||||||
},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_knn_model_build():
|
|
||||||
train_ds = prototorch.datasets.Iris(dims=[0, 2])
|
|
||||||
model = prototorch.models.KNN(dict(k=3), data=train_ds)
|
|
||||||
|
|
||||||
|
|
||||||
def test_lvq1_model_build():
|
|
||||||
model = prototorch.models.LVQ1(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_lvq21_model_build():
|
|
||||||
model = prototorch.models.LVQ21(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_median_lvq_model_build():
|
|
||||||
model = prototorch.models.MedianLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_celvq_model_build():
|
|
||||||
model = prototorch.models.CELVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rslvq_model_build():
|
|
||||||
model = prototorch.models.RSLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_slvq_model_build():
|
|
||||||
model = prototorch.models.SLVQ(
|
|
||||||
{"distribution": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_growing_neural_gas_model_build():
|
|
||||||
model = prototorch.models.GrowingNeuralGas(
|
|
||||||
{"num_prototypes": 5},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_kohonen_som_model_build():
|
|
||||||
model = prototorch.models.KohonenSOM(
|
|
||||||
{"shape": (3, 2)},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_neural_gas_model_build():
|
|
||||||
model = prototorch.models.NeuralGas(
|
|
||||||
{"num_prototypes": 5},
|
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
|
||||||
)
|
|
Loading…
Reference in New Issue
Block a user