30 Commits

Author SHA1 Message Date
Alexander Engelsberger
d6629c8792 build: bump version 0.4.1 → 0.5.0 2022-04-27 10:28:06 +02:00
Alexander Engelsberger
ef65bd3789 chore: update prototorch dependency 2022-04-27 09:50:48 +02:00
Alexander Engelsberger
d096eba2c9 chore: update pytorch lightning dependency 2022-04-27 09:39:00 +02:00
Alexander Engelsberger
dd34c57e2e ci: fix github action python version 2022-04-27 09:30:07 +02:00
Alexander Engelsberger
5911f4dd90 chore: fix errors for pytorch_lightning>1.6 2022-04-27 09:25:42 +02:00
Alexander Engelsberger
dbfe315f4f ci: add python 3.10 to tests 2022-04-27 09:24:34 +02:00
Jensun Ravichandran
9c90c902dc fix: correct typo 2022-04-04 21:54:04 +02:00
Jensun Ravichandran
7d3f59e54b test: add unit tests 2022-03-30 15:12:33 +02:00
Jensun Ravichandran
9da47b1dba fix: CBC example works again 2022-03-30 15:10:06 +02:00
Alexander Engelsberger
41f0e77fc9 fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
2022-03-29 17:08:40 +02:00
Jensun Ravichandran
fab786a07e fix: rename hparam output_dimlatent_dim in SiameseGMLVQ 2022-03-29 15:24:42 +02:00
Jensun Ravichandran
40bd7ed380 docs: update tutorial notebook 2022-03-29 15:04:05 +02:00
Jensun Ravichandran
4941c2b89d feat: data argument optional in some visualizers 2022-03-29 11:26:22 +02:00
Jensun Ravichandran
ce14dec7e9 feat: add VisSpectralProtos 2022-03-10 15:24:44 +01:00
Jensun Ravichandran
b31c8cc707 feat: add xlabel and ylabel arguments to visualizers 2022-03-09 13:59:19 +01:00
Jensun Ravichandran
e21e6c7e02 docs: update tutorial notebook 2022-02-15 14:38:53 +01:00
Jensun Ravichandran
dd696ea1e0 fix: update hparams.distribution as it changes during training 2022-02-02 21:53:03 +01:00
Jensun Ravichandran
15e7232747 fix: ignore prototype_win_ratios by loading with strict=False 2022-02-02 21:52:01 +01:00
Jensun Ravichandran
197b728c63 feat: add visualize method to visualization callbacks
All visualization callbacks now contain a `visualize` method that takes an
appropriate PyTorchLightning Module and visualizes it without the need for a
Trainer. This is to encourage users to perform one-off visualizations after
training.
2022-02-02 21:45:44 +01:00
Jensun Ravichandran
98892afee0 chore: add example for saving/loading models from checkpoints 2022-02-02 19:02:26 +01:00
Alexander Engelsberger
d5855dbe97 fix: GLVQ can now be restored from checkpoint 2022-02-02 16:17:11 +01:00
Alexander Engelsberger
75a39f5b03 build: bump version 0.4.0 → 0.4.1 2022-01-11 18:29:55 +01:00
Alexander Engelsberger
1a0e697b27 Merge branch 'dev' into main 2022-01-11 18:29:32 +01:00
Alexander Engelsberger
1a17193b35 ci: add github actions (#16)
* chore: update pre-commit versions

* ci: remove old configurations

* ci: copy workflow from prototorch

* ci: run precommit for all files

* ci: add examples CPU test

* ci(test): failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test II

* ci: fix test command

* ci: cleanup example test

* ci: remove travis badge
2022-01-11 18:28:50 +01:00
Alexander Engelsberger
aaa3c51e0a build: bump version 0.3.0 → 0.4.0 2021-12-09 15:58:16 +01:00
Jensun Ravichandran
62c5974a85 fix: correct typo in example script 2021-11-17 15:01:38 +01:00
Jensun Ravichandran
1d26226a2f fix(warning): specify dimension explicitly when calling softmin 2021-11-16 10:19:31 +01:00
Christoph
4232d0ed2a fix: spelling issues for previous commits 2021-11-15 11:43:39 +01:00
Christoph
a9edf06507 feat: ImageGTLVQ and SiameseGTLVQ with examples 2021-11-15 11:43:39 +01:00
Christoph
d3bb430104 feat: gtlvq with examples 2021-11-15 11:43:39 +01:00
39 changed files with 1337 additions and 782 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.3.0 current_version = 0.5.0
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -1,15 +0,0 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

View File

@@ -1,2 +0,0 @@
comment:
require_changes: yes

25
.github/workflows/examples.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: examples
on:
push:
paths:
- 'examples/**.py'
jobs:
cpu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/

75
.github/workflows/pythonapp.yml vendored Normal file
View File

@@ -0,0 +1,75 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: tests
on:
push:
pull_request:
branches: [ master ]
jobs:
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v2.0.3
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -3,7 +3,7 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1 rev: v4.1.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: end-of-file-fixer - id: end-of-file-fixer
@@ -18,19 +18,19 @@ repos:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.9.3 rev: 5.10.1
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1 rev: v0.931
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
additional_dependencies: [types-pkg_resources] additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0 rev: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
@@ -42,10 +42,9 @@ repos:
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.29.0 rev: v2.31.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: [--py36-plus]
- repo: https://github.com/si-cim/gitlint - repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial rev: v0.15.2-unofficial

View File

@@ -1,44 +0,0 @@
dist: bionic
sudo: false
language: python
python:
- 3.9
- 3.8
- 3.7
- 3.6
cache:
directories:
- "$HOME/.cache/pip"
- "./tests/artifacts"
- "$HOME/datasets"
install:
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
- pip install .[all] --progress-bar off
script:
- coverage run -m pytest
- ./tests/test_examples.sh examples/
after_success:
- bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
jobs:
include:
- stage: build
python: 3.9
script: echo "Starting Pypi build"
deploy:
provider: pypi
username: __token__
distributions: "sdist bdist_wheel"
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@@ -1,6 +1,5 @@
# ProtoTorch Models # ProtoTorch Models
[![Build Status](https://api.travis-ci.com/si-cim/prototorch_models.svg?branch=main)](https://travis-ci.com/github/si-cim/prototorch_models)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases) [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/) [![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE) [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.3.0" release = "0.5.0"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -2,223 +2,252 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "7ac5eff0",
"metadata": {},
"source": [ "source": [
"# A short tutorial for the `prototorch.models` plugin" "# A short tutorial for the `prototorch.models` plugin"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "beb83780",
"metadata": {},
"source": [ "source": [
"## Introduction" "## Introduction"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "43b74278",
"metadata": {},
"source": [ "source": [
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n", "This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
"\n", "\n",
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n", "[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
"\n", "\n",
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this." "With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "4e5d1fad",
"metadata": {},
"source": [ "source": [
"## Basics" "## Basics"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1244b66b",
"metadata": {},
"source": [ "source": [
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:" "First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "dcb88e8a",
"metadata": {},
"outputs": [],
"source": [ "source": [
"import prototorch as pt\n", "import prototorch as pt\n",
"import pytorch_lightning as pl\n", "import pytorch_lightning as pl\n",
"import torch" "import torch"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1adbe2f8",
"metadata": {},
"source": [ "source": [
"### Building Models" "### Building Models"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "96663ab1",
"metadata": {},
"source": [ "source": [
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer." "Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "819ba756",
"metadata": {},
"outputs": [],
"source": [ "source": [
"model = pt.models.GLVQ(\n", "model = pt.models.GLVQ(\n",
" hparams=dict(distribution=[1, 1, 1]),\n", " hparams=dict(distribution=[1, 1, 1]),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n", " prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")" ")"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "1b37e97c",
"metadata": {},
"outputs": [],
"source": [ "source": [
"print(model)" "print(model)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "d2c86903",
"metadata": {},
"source": [ "source": [
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n", "The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
"\n", "\n",
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros." "The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "45806052",
"metadata": {},
"source": [ "source": [
"### Data" "### Data"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "9d62c4c6",
"metadata": {},
"source": [ "source": [
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets." "The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "504df02c",
"metadata": {},
"outputs": [],
"source": [ "source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])" "train_ds = pt.datasets.Iris(dims=[0, 2])"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "3b8e7756",
"metadata": {},
"outputs": [],
"source": [ "source": [
"type(train_ds)" "type(train_ds)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "bce43afa",
"metadata": {},
"outputs": [],
"source": [ "source": [
"train_ds.data.shape, train_ds.targets.shape" "train_ds.data.shape, train_ds.targets.shape"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "26a83328",
"metadata": {},
"source": [ "source": [
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly." "Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "67b80fbe",
"metadata": {},
"outputs": [],
"source": [ "source": [
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)" "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "c1185f31",
"metadata": {},
"outputs": [],
"source": [ "source": [
"type(train_loader)" "type(train_loader)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "9b5a8963",
"metadata": {},
"outputs": [],
"source": [ "source": [
"x_batch, y_batch = next(iter(train_loader))\n", "x_batch, y_batch = next(iter(train_loader))\n",
"print(f\"{x_batch=}, {y_batch=}\")" "print(f\"{x_batch=}, {y_batch=}\")"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "dd492ee2",
"metadata": {},
"source": [ "source": [
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches." "This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "5176b055",
"metadata": {},
"source": [ "source": [
"### Training" "### Training"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "46a7a506",
"metadata": {},
"source": [ "source": [
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented." "If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "279e75b7",
"metadata": {},
"outputs": [],
"source": [ "source": [
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)" "trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "e496b492",
"metadata": {},
"outputs": [],
"source": [ "source": [
"trainer.fit(model, train_loader)" "trainer.fit(model, train_loader)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "497fbff6",
"metadata": {},
"source": [ "source": [
"### From data to a trained model - a very minimal example" "### From data to a trained model - a very minimal example"
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "ab069c5d",
"metadata": {},
"outputs": [],
"source": [ "source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])\n", "train_ds = pt.datasets.Iris(dims=[0, 2])\n",
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n", "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
@@ -230,49 +259,239 @@
"\n", "\n",
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n", "trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
"trainer.fit(model, train_loader)" "trainer.fit(model, train_loader)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "30c71a93",
"metadata": {},
"source": [ "source": [
"## Advanced" "### Saving/Loading trained models"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "f74ed2c1",
"metadata": {},
"source": [ "source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)" "Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "3156658d",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
"trainer.save_checkpoint(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c34055",
"metadata": {},
"outputs": [],
"source": [
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
]
},
{
"cell_type": "markdown",
"id": "bbbb08e9",
"metadata": {},
"source": [
"### Visualizing decision boundaries in 2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ca52dc",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
]
},
{
"cell_type": "markdown",
"id": "8373531f",
"metadata": {},
"source": [
"### Saving/Loading trained weights"
]
},
{
"cell_type": "markdown",
"id": "937bc458",
"metadata": {},
"source": [
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2035af",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
"torch.save(model.state_dict(), ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1206021a",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2)),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2a4beb",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d2fc2",
"metadata": {},
"outputs": [],
"source": [
"torch.load(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec817e6b",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a208eab7",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "f8de748f",
"metadata": {},
"source": [
"## Advanced"
]
},
{
"cell_type": "markdown",
"id": "53a64063",
"metadata": {},
"source": [
"### Warm-start a model with prototypes learned from another model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3177c277",
"metadata": {},
"outputs": [],
"source": [
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
"model = pt.models.SiameseGMLVQ(\n",
" dict(input_dim=2,\n",
" latent_dim=2,\n",
" distribution=(3, 2),\n",
" proto_lr=0.0001,\n",
" bb_lr=0.0001),\n",
" optimizer=torch.optim.Adam,\n",
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8baee9a2",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc203088",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "1f6a33a5",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "946ce341",
"metadata": {},
"outputs": [],
"source": [ "source": [
"import prototorch as pt\n", "import prototorch as pt\n",
"import pytorch_lightning as pl\n", "import pytorch_lightning as pl\n",
"import torch\n", "import torch\n",
"from torchvision import transforms\n", "from torchvision import transforms\n",
"from torchvision.datasets import MNIST" "from torchvision.datasets import MNIST\n",
], "from torchvision.utils import make_grid"
"outputs": [], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "510d9bd4",
"metadata": {},
"outputs": [],
"source": [ "source": [
"from matplotlib import pyplot as plt" "from matplotlib import pyplot as plt"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "ea7c1228",
"metadata": {},
"outputs": [],
"source": [ "source": [
"train_ds = MNIST(\n", "train_ds = MNIST(\n",
" \"~/datasets\",\n", " \"~/datasets\",\n",
@@ -284,59 +503,87 @@
" transforms.ToTensor(),\n", " transforms.ToTensor(),\n",
" ]),\n", " ]),\n",
")" ")"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "1b9eaf5c",
"metadata": {},
"outputs": [],
"source": [ "source": [
"s = int(0.05 * len(train_ds))\n", "s = int(0.05 * len(train_ds))\n",
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])" "init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "8c32c9f2",
"metadata": {},
"outputs": [],
"source": [ "source": [
"init_ds" "init_ds"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "68a9a8b9",
"metadata": {},
"outputs": [],
"source": [ "source": [
"model = pt.models.ImageGLVQ(\n", "model = pt.models.ImageGLVQ(\n",
" dict(distribution=(10, 5)),\n", " dict(distribution=(10, 1)),\n",
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n", " prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
")" ")"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"source": [ "id": "6f23df86",
"plt.imshow(model.get_prototype_grid(num_columns=10))" "metadata": {},
],
"outputs": [], "outputs": [],
"metadata": {} "source": [
"plt.imshow(model.get_prototype_grid(num_columns=5))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1c23c7b2",
"metadata": {},
"source": [
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30780927",
"metadata": {},
"outputs": [],
"source": [
"protos, plabels = pt.components.LabeledComponents(\n",
" distribution=(10, 5),\n",
" components_initializer=pt.initializers.SMCI(init_ds),\n",
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
")()\n",
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
]
},
{
"cell_type": "markdown",
"id": "4fa69f92",
"metadata": {},
"source": [ "source": [
"## FAQs" "## FAQs"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "fa20f9ac",
"metadata": {},
"source": [ "source": [
"### How do I Retrieve the prototypes and their respective labels from the model?\n", "### How do I Retrieve the prototypes and their respective labels from the model?\n",
"\n", "\n",
@@ -351,11 +598,12 @@
"```python\n", "```python\n",
">>> model.prototype_labels\n", ">>> model.prototype_labels\n",
"```" "```"
], ]
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "ba8215bf",
"metadata": {},
"source": [ "source": [
"### How do I make inferences/predictions/recall with my trained model?\n", "### How do I make inferences/predictions/recall with my trained model?\n",
"\n", "\n",
@@ -370,13 +618,12 @@
"```python\n", "```python\n",
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n", ">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
"```" "```"
], ]
"metadata": {}
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@@ -390,7 +637,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.4" "version": "3.9.12"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -38,12 +38,10 @@ if __name__ == "__main__":
) )
# Callbacks # Callbacks
vis = pt.models.Visualize2DVoronoiCallback( vis = pt.models.VisCBC2D(data=train_ds,
data=train_ds,
title="CBC Iris Example", title="CBC Iris Example",
resolution=100, resolution=100,
axis_off=True, axis_off=True)
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(

View File

@@ -3,7 +3,6 @@
import argparse import argparse
import prototorch as pt import prototorch as pt
import prototorch.models.clcc
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
@@ -30,7 +29,7 @@ if __name__ == "__main__":
) )
# Initialize the model # Initialize the model
model = prototorch.models.GLVQ( model = pt.models.GLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -42,13 +41,7 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2) model.example_input_array = torch.zeros(4, 2)
# Callbacks # Callbacks
vis = pt.models.Visualize2DVoronoiCallback( vis = pt.models.VisGLVQ2D(data=train_ds)
data=train_ds,
resolution=200,
title="Example: GLVQ on Iris",
x_label="sepal length",
y_label="petal length",
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
@@ -60,3 +53,13 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./glvq_iris.ckpt")
# Load saved model
new_model = pt.models.GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
print(new_model)

View File

@@ -1,4 +1,4 @@
"""GLVQ example using the spiral dataset.""" """GMLVQ example using the spiral dataset."""
import argparse import argparse

104
examples/gtlvq_mnist.py Normal file
View File

@@ -0,0 +1,104 @@
"""GTLVQ example using the MNIST dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
test_ds = MNIST(
"~/datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=256)
test_loader = torch.utils.data.DataLoader(test_ds,
num_workers=0,
batch_size=256)
# Hyperparameters
num_classes = 10
prototypes_per_class = 1
hparams = dict(
input_dim=28 * 28,
latent_dim=28,
distribution=(num_classes, prototypes_per_class),
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.ImageGTLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
#Use one batch of data for subspace initiator.
omega_initializer=pt.initializers.PCALinearTransformInitializer(
next(iter(train_loader))[0].reshape(256, 28 * 28)))
# Callbacks
vis = pt.models.VisImgComp(
data=train_ds,
num_columns=10,
show=False,
tensorboard=True,
random_data=100,
add_embedding=True,
embedding_data=200,
flatten_data=False,
)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
# using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
# es,
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

63
examples/gtlvq_moons.py Normal file
View File

@@ -0,0 +1,63 @@
"""Localized-GTLVQ example using the Moons dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
# Latent_dim should be lower than input dim.
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
# Initialize the model
model = pt.models.GTLVQ(
hparams, prototypes_initializer=pt.initializers.SMCI(train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -10,6 +10,7 @@ from prototorch.utils.colors import hex_to_rgb
class Vis2DColorSOM(pl.Callback): class Vis2DColorSOM(pl.Callback):
def __init__(self, data, title="ColorSOMe", pause_time=0.1): def __init__(self, data, title="ColorSOMe", pause_time=0.1):
super().__init__() super().__init__()
self.title = title self.title = title

View File

@@ -8,6 +8,7 @@ import torch
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size

View File

@@ -8,6 +8,7 @@ import torch
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__() super().__init__()
self.input_size = input_size self.input_size = input_size

View File

@@ -0,0 +1,73 @@
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
input_dim=2,
latent_dim=1)
# Initialize the backbone
backbone = Backbone(latent_size=hparams["input_dim"])
# Initialize the model
model = pt.models.SiameseGTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -8,17 +8,32 @@ from .glvq import (
GLVQ21, GLVQ21,
GMLVQ, GMLVQ,
GRLVQ, GRLVQ,
GTLVQ,
LGMLVQ, LGMLVQ,
LVQMLN, LVQMLN,
ImageGLVQ, ImageGLVQ,
ImageGMLVQ, ImageGMLVQ,
ImageGTLVQ,
SiameseGLVQ, SiameseGLVQ,
SiameseGMLVQ, SiameseGMLVQ,
SiameseGTLVQ,
) )
from .knn import KNN from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import (
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ LVQ1,
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas LVQ21,
MedianLVQ,
)
from .probabilistic import (
CELVQ,
RSLVQ,
SLVQ,
)
from .unsupervised import (
GrowingNeuralGas,
KohonenSOM,
NeuralGas,
)
from .vis import * from .vis import *
__version__ = "0.3.0" __version__ = "0.5.0"

View File

@@ -3,16 +3,18 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.core.competitions import WTAC
from prototorch.core.components import Components, LabeledComponents from ..core.competitions import WTAC
from prototorch.core.distances import euclidean_distance from ..core.components import Components, LabeledComponents
from prototorch.core.initializers import LabelsInitializer from ..core.distances import euclidean_distance
from prototorch.core.pooling import stratified_min_pooling from ..core.initializers import LabelsInitializer, ZerosCompInitializer
from prototorch.nn.wrappers import LambdaLayer from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule): class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts.""" """All ProtoTorch models are ProtoTorch Bolts."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()
@@ -41,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
return optimizer return optimizer
def reconfigure_optimizers(self): def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer) self.trainer.strategy.setup_optimizers(self.trainer)
def __repr__(self): def __repr__(self):
surep = super().__repr__() surep = super().__repr__()
@@ -51,6 +53,7 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt): class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -72,14 +75,17 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs): def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs) self.proto_layer.add_components(*args, **kwargs)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers() self.reconfigure_optimizers()
def remove_prototypes(self, indices): def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices) self.proto_layer.remove_components(indices)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers() self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel): class UnsupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -102,19 +108,33 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel): class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer", labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer()) LabelsInitializer())
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None: if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=distribution,
components_initializer=prototypes_initializer, components_initializer=prototypes_initializer,
labels_initializer=labels_initializer, labels_initializer=labels_initializer,
) )
proto_shape = self.proto_layer.components.shape[1:]
self.hparams.initialized_proto_shape = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape),
)
self.competition_layer = WTAC() self.competition_layer = WTAC()
@property @property
@@ -134,7 +154,7 @@ class SupervisedPrototypeModel(PrototypeModel):
distances = self.compute_distances(x) distances = self.compute_distances(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels) winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning) y_pred = torch.nn.functional.softmin(winning, dim=1)
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
@@ -168,3 +188,33 @@ class SupervisedPrototypeModel(PrototypeModel):
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log("test_acc", accuracy) self.log("test_acc", accuracy)
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@@ -4,13 +4,14 @@ import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.core.components import Components
from prototorch.core.initializers import LiteralCompInitializer
from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology from .extras import ConnectionTopology
class PruneLoserPrototypes(pl.Callback): class PruneLoserPrototypes(pl.Callback):
def __init__(self, def __init__(self,
threshold=0.01, threshold=0.01,
idle_epochs=10, idle_epochs=10,
@@ -67,6 +68,7 @@ class PruneLoserPrototypes(pl.Callback):
class PrototypeConvergence(pl.Callback): class PrototypeConvergence(pl.Callback):
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False): def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
self.min_delta = min_delta self.min_delta = min_delta
self.idle_epochs = idle_epochs # epochs to wait self.idle_epochs = idle_epochs # epochs to wait
@@ -89,6 +91,7 @@ class GNGCallback(pl.Callback):
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke. Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
""" """
def __init__(self, reduction=0.1, freq=10): def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction self.reduction = reduction
self.freq = freq self.freq = freq
@@ -134,4 +137,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[ pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator.setup_optimizers(trainer) trainer.strategy.setup_optimizers(trainer)

View File

@@ -1,20 +1,21 @@
import torch import torch
import torchmetrics import torchmetrics
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import RandomReasoningsInitializer
from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import CBCC
from ..core.components import ReasoningComponents
from ..core.initializers import RandomReasoningsInitializer
from ..core.losses import MarginLoss
from ..core.similarities import euclidean_similarity
from ..nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ from .glvq import SiameseGLVQ
from .mixin import ImagePrototypesMixin
class CBC(SiameseGLVQ): class CBC(SiameseGLVQ):
"""Classification-By-Components.""" """Classification-By-Components."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, skip_proto_layer=True, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity) similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None) components_initializer = kwargs.get("components_initializer", None)

View File

@@ -1,86 +0,0 @@
from dataclasses import dataclass
from typing import Callable
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.nn.wrappers import LambdaLayer
@dataclass
class GLVQhparams:
distribution: dict
component_initializer: AbstractComponentsInitializer
distance_fn: Callable = euclidean_distance
lr: float = 0.01
margin: float = 0.0
# TODO: make nicer
transfer_fn: str = "identity"
transfer_beta: float = 10.0
optimizer: torch.optim.Optimizer = torch.optim.Adam
class GLVQ(CLCCScheme):
def __init__(self, hparams: GLVQhparams) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Initializers
def init_components(self, hparams):
# initialize Component Layer
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
def init_comparison(self, hparams):
# initialize Distance Layer
self.comparison_layer = LambdaLayer(hparams.distance_fn)
def init_inference(self, hparams):
self.competition_layer = WTAC()
def init_loss(self, hparams):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
beta=hparams.transfer_beta,
)
# Steps
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances
def inference(self, comparisonmeasures, components):
comp_labels = components[1]
return self.competition_layer(comparisonmeasures, comp_labels)
def loss(self, comparisonmeasures, batch, components):
target = batch[1]
comp_labels = components[1]
return self.loss_layer(comparisonmeasures, target, comp_labels)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr)
# Properties
@property
def prototypes(self):
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
return self.components_layer.labels.detach().cpu()

View File

@@ -1,192 +0,0 @@
"""
CLCC Scheme
CLCC is a LVQ scheme containing 4 steps
- Components
- Latent Space
- Comparison
- Competition
"""
from typing import Dict, Set, Type
import pytorch_lightning as pl
import torch
import torchmetrics
class CLCCScheme(pl.LightningModule):
registered_metrics: Dict[Type[torchmetrics.Metric],
torchmetrics.Metric] = {}
registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {}
def __init__(self, hparams) -> None:
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(self, name: str, metric: torchmetrics.Metric):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric()
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competion(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams):
...
def init_latent(self, hparams):
...
def init_comparison(self, hparams):
...
def init_competition(self, hparams):
...
def init_loss(self, hparams):
...
def init_inference(self, hparams):
...
def init_model_metrics(self):
self.register_torchmetric('train_accuracy', torchmetrics.Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the componentsset of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparisonmeasures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparisonmeasures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparisonmeasures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance(y, preds)
for name in self.registered_metric_names[metric]:
self.log(name, value)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
# Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch)
def train_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)

View File

@@ -1,76 +0,0 @@
from typing import Optional
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF
# ##############################################################################
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
metric_name = "test_cb_acc"
def setup(self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
stage: Optional[str] = None) -> None:
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
def on_epoch_end(self, trainer: pl.Trainer,
pl_module: pl.LightningModule) -> None:
metric = trainer.logged_metrics[self.metric_name]
if metric > 0.95:
trainer.should_stop = True
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=64,
num_workers=8)
components_initializer = SMCI(train_ds)
hparams = GLVQhparams(
distribution=dict(
num_classes=3,
per_class=2,
),
component_initializer=components_initializer,
)
model = GLVQ(hparams)
print(model)
# Callbacks
vis = Visualize2DVoronoiCallback(
data=train_ds,
resolution=500,
)
metrics = MetricsTestCallback()
# Train
trainer = pl.Trainer(
callbacks=[
#vis,
metrics,
],
gpus=1,
max_epochs=100,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model, train_loader)

View File

@@ -5,7 +5,8 @@ Modules not yet available in prototorch go here temporarily.
""" """
import torch import torch
from prototorch.core.similarities import gaussian
from ..core.similarities import gaussian
def rank_scaled_gaussian(distances, lambd): def rank_scaled_gaussian(distances, lambd):
@@ -14,7 +15,46 @@ def rank_scaled_gaussian(distances, lambd):
return torch.exp(-torch.exp(-ranks / lambd) * distances) return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module): class GaussianPrior(torch.nn.Module):
def __init__(self, variance): def __init__(self, variance):
super().__init__() super().__init__()
self.variance = variance self.variance = variance
@@ -24,6 +64,7 @@ class GaussianPrior(torch.nn.Module):
class RankScaledGaussianPrior(torch.nn.Module): class RankScaledGaussianPrior(torch.nn.Module):
def __init__(self, lambd): def __init__(self, lambd):
super().__init__() super().__init__()
self.lambd = lambd self.lambd = lambd
@@ -33,6 +74,7 @@ class RankScaledGaussianPrior(torch.nn.Module):
class ConnectionTopology(torch.nn.Module): class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes): def __init__(self, agelimit, num_prototypes):
super().__init__() super().__init__()
self.agelimit = agelimit self.agelimit = agelimit

View File

@@ -1,20 +1,29 @@
"""Models based on the GLVQ framework.""" """Models based on the GLVQ framework."""
import torch import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from prototorch.core.initializers import EyeTransformInitializer
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .abstract import SupervisedPrototypeModel from ..core.competitions import wtac
from .mixin import ImagePrototypesMixin from ..core.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from ..core.initializers import EyeLinearTransformInitializer
from ..core.losses import (
GLVQLoss,
lvq1_loss,
lvq21_loss,
)
from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
class GLVQ(SupervisedPrototypeModel): class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization.""" """Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -30,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
beta=self.hparams.transfer_beta, beta=self.hparams.transfer_beta,
) )
# def on_save_checkpoint(self, checkpoint):
# if "prototype_win_ratios" in checkpoint["state_dict"]:
# del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(
"prototype_win_ratios", "prototype_win_ratios",
@@ -99,6 +112,7 @@ class SiameseGLVQ(GLVQ):
transformation pipeline are only learned from the inputs. transformation pipeline are only learned from the inputs.
""" """
def __init__(self, def __init__(self,
hparams, hparams,
backbone=torch.nn.Identity(), backbone=torch.nn.Identity(),
@@ -131,11 +145,15 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos)) x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x) latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos) latent_protos = self.backbone(protos)
self.backbone.requires_grad_(True) self.backbone.requires_grad_(bb_grad)
distances = self.distance_layer(latent_x, latent_protos) distances = self.distance_layer(latent_x, latent_protos)
return distances return distances
@@ -165,6 +183,7 @@ class LVQMLN(SiameseGLVQ):
rather in the embedding space. rather in the embedding space.
""" """
def compute_distances(self, x): def compute_distances(self, x):
latent_protos, _ = self.proto_layer() latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
@@ -180,6 +199,7 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
""" """
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -205,15 +225,16 @@ class SiameseGMLVQ(SiameseGLVQ):
Implemented as a Siamese network with a linear transformation backbone. Implemented as a Siamese network with a linear transformation backbone.
""" """
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Override the backbone # Override the backbone
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer()) EyeLinearTransformInitializer())
self.backbone = LinearTransform( self.backbone = LinearTransform(
self.hparams.input_dim, self.hparams.input_dim,
self.hparams.output_dim, self.hparams.latent_dim,
initializer=omega_initializer, initializer=omega_initializer,
) )
@@ -235,13 +256,14 @@ class GMLVQ(GLVQ):
function. This makes it easier to implement a localized variant. function. This makes it easier to implement a localized variant.
""" """
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance) distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters # Additional parameters
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer()) EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim, omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim) self.hparams.latent_dim)
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
@@ -269,6 +291,7 @@ class GMLVQ(GLVQ):
class LGMLVQ(GMLVQ): class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization.""" """Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", lomega_distance) distance_fn = kwargs.pop("distance_fn", lomega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@@ -283,8 +306,48 @@ class LGMLVQ(GMLVQ):
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Tangent Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None:
subspace = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
omega = torch.repeat_interleave(subspace.unsqueeze(0),
self.num_prototypes,
dim=0)
else:
omega = torch.rand(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
device=self.device,
)
# Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
"""Generalized Tangent Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
"""
class GLVQ1(GLVQ): class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1.""" """Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq1_loss) self.loss = LossLayer(lvq1_loss)
@@ -293,6 +356,7 @@ class GLVQ1(GLVQ):
class GLVQ21(GLVQ): class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1.""" """Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq21_loss) self.loss = LossLayer(lvq21_loss)
@@ -315,3 +379,18 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
after updates. after updates.
""" """
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
"""GTLVQ for training on image data.
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))

View File

@@ -2,18 +2,21 @@
import warnings import warnings
from prototorch.core.competitions import KNNC from ..core.competitions import KNNC
from prototorch.core.components import LabeledComponents from ..core.components import LabeledComponents
from prototorch.core.initializers import LiteralCompInitializer, LiteralLabelsInitializer from ..core.initializers import (
from prototorch.utils.utils import parse_data_arg LiteralCompInitializer,
LiteralLabelsInitializer,
)
from ..utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel): class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm.""" """K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, skip_proto_layer=True, **kwargs)
# Default hparams # Default hparams
self.hparams.setdefault("k", 1) self.hparams.setdefault("k", 1)
@@ -25,7 +28,7 @@ class KNN(SupervisedPrototypeModel):
# Layers # Layers
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=[], distribution=len(data) * [1],
components_initializer=LiteralCompInitializer(data), components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets)) labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k) self.competition_layer = KNNC(k=self.hparams.k)

View File

@@ -1,15 +1,15 @@
"""LVQ models that are optimized using non-gradient methods.""" """LVQ models that are optimized using non-gradient methods."""
from prototorch.core.losses import _get_dp_dm from ..core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation from ..nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ from .glvq import GLVQ
from .mixin import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ): class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1.""" """Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plables = self.proto_layer() protos, plables = self.proto_layer()
x, y = train_batch x, y = train_batch
@@ -39,6 +39,7 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ): class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1.""" """Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
@@ -71,6 +72,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
# TODO Avoid computing distances over and over # TODO Avoid computing distances over and over
""" """
def __init__(self, hparams, verbose=True, **kwargs): def __init__(self, hparams, verbose=True, **kwargs):
self.verbose = verbose self.verbose = verbose
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)

View File

@@ -1,27 +0,0 @@
class ProtoTorchMixin:
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@@ -1,16 +1,17 @@
"""Probabilistic GLVQ methods""" """Probabilistic GLVQ methods"""
import torch import torch
from prototorch.core.losses import nllr_loss, rslvq_loss
from prototorch.core.pooling import stratified_min_pooling, stratified_sum_pooling
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ): class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization.""" """Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -29,6 +30,7 @@ class CELVQ(GLVQ):
class ProbabilisticLVQ(GLVQ): class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs): def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -62,18 +64,30 @@ class ProbabilisticLVQ(GLVQ):
class SLVQ(ProbabilisticLVQ): class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization.""" """Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss) self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ): class RSLVQ(ProbabilisticLVQ):
"""Robust Soft Learning Vector Quantization.""" """Robust Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss) self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
@@ -81,10 +95,15 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
TODO: Use Backbone LVQ instead TODO: Use Backbone LVQ instead
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd) # Default hparams
self.hparams.setdefault("lambda", 1.0)
lam = self.hparams.get("lambda", 1.0)
self.conditional_distribution = RankScaledGaussianPrior(lam)
self.loss = torch.nn.KLDivLoss() self.loss = torch.nn.KLDivLoss()
# FIXME # FIXME

View File

@@ -2,15 +2,14 @@
import numpy as np import numpy as np
import torch import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from prototorch.nn.wrappers import LambdaLayer
from .abstract import UnsupervisedPrototypeModel from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback from .callbacks import GNGCallback
from .extras import ConnectionTopology from .extras import ConnectionTopology
from .mixin import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
@@ -19,6 +18,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
TODO Allow non-2D grids TODO Allow non-2D grids
""" """
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape") h, w = hparams.get("shape")
# Ignore `num_prototypes` # Ignore `num_prototypes`
@@ -35,7 +35,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
# Additional parameters # Additional parameters
x, y = torch.arange(h), torch.arange(w) x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1) grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
self.register_buffer("_grid", grid) self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma self._sigma = self.hparams.sigma
self._lr = self.hparams.lr self._lr = self.hparams.lr
@@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
diff = x.unsqueeze(dim=1) - protos diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0) updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos}, self.proto_layer.load_state_dict(
strict=False) {"_components": updated_protos},
strict=False,
)
def training_epoch_end(self, training_step_outputs): def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp( self._sigma = self.hparams.sigma * np.exp(
@@ -70,6 +72,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
class HeskesSOM(UnsupervisedPrototypeModel): class HeskesSOM(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -79,6 +82,7 @@ class HeskesSOM(UnsupervisedPrototypeModel):
class NeuralGas(UnsupervisedPrototypeModel): class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -86,12 +90,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
# Default hparams # Default hparams
self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1) self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm) self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology( self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit, agelimit=self.hparams.age_limit,
num_prototypes=self.hparams.num_prototypes, num_prototypes=self.hparams.num_prototypes,
) )
@@ -111,6 +115,7 @@ class NeuralGas(UnsupervisedPrototypeModel):
class GrowingNeuralGas(NeuralGas): class GrowingNeuralGas(NeuralGas):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@@ -142,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self): def configure_callbacks(self):
return [ return [
GNGCallback(reduction=self.hparams.insert_reduction, GNGCallback(
freq=self.hparams.insert_freq) reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq,
)
] ]

View File

@@ -5,19 +5,21 @@ import pytorch_lightning as pl
import torch import torch
import torchvision import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.utils.utils import generate_mesh, mesh2d
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
COLOR_UNLABELED = 'w' from ..utils.colors import get_colors, get_legend_handles
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
data, data=None,
title=None, title="Prototype Visualization",
x_label=None,
y_label=None,
cmap="viridis", cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
legend_labels=None,
border=0.1, border=0.1,
resolution=100, resolution=100,
flatten_data=True, flatten_data=True,
@@ -30,6 +32,7 @@ class Vis2DAbstract(pl.Callback):
block=False): block=False):
super().__init__() super().__init__()
if data:
if isinstance(data, Dataset): if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data)))) x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader): elif isinstance(data, torch.utils.data.DataLoader):
@@ -46,10 +49,14 @@ class Vis2DAbstract(pl.Callback):
self.x_train = x self.x_train = x
self.y_train = y self.y_train = y
else:
self.x_train = None
self.y_train = None
self.title = title self.title = title
self.x_label = x_label self.xlabel = xlabel
self.y_label = y_label self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title) self.fig = plt.figure(self.title)
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
@@ -62,8 +69,9 @@ class Vis2DAbstract(pl.Callback):
self.pause_time = pause_time self.pause_time = pause_time
self.block = block self.block = block
def show_on_current_epoch(self, trainer): def precheck(self, trainer):
if self.show_last_only and trainer.current_epoch != trainer.max_epochs - 1: if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return False return False
return True return True
@@ -71,10 +79,8 @@ class Vis2DAbstract(pl.Callback):
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
if self.x_label: ax.set_xlabel(self.xlabel)
ax.set_xlabel(self.x_label) ax.set_ylabel(self.ylabel)
if self.x_label:
ax.set_ylabel(self.y_label)
if self.axis_off: if self.axis_off:
ax.axis("off") ax.axis("off")
return ax return ax
@@ -117,81 +123,44 @@ class Vis2DAbstract(pl.Callback):
else: else:
plt.show(block=self.block) plt.show(block=self.block)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
self.log_and_display(trainer, pl_module)
def on_train_end(self, trainer, pl_module): def on_train_end(self, trainer, pl_module):
plt.close() plt.close()
class Visualize2DVoronoiCallback(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.data_min = torch.min(self.x_train, axis=0).values
self.data_max = torch.max(self.x_train, axis=0).values
def current_span(self, proto_values):
proto_min = torch.min(proto_values, axis=0).values
proto_max = torch.max(proto_values, axis=0).values
overall_min = torch.minimum(proto_min, self.data_min)
overall_max = torch.maximum(proto_max, self.data_max)
return overall_min, overall_max
def get_voronoi_diagram(self, min, max, model):
mesh_input, (xx, yy) = generate_mesh(
min,
max,
border=self.border,
resolution=self.resolution,
device=model.device,
)
y_pred = model.predict(mesh_input)
return xx, yy, y_pred.reshape(xx.shape)
def on_epoch_end(self, trainer, pl_module):
if not self.show_on_current_epoch(trainer):
return True
# Extract Prototypes
proto_values = pl_module.prototypes
if hasattr(pl_module, "prototype_labels"):
proto_labels = pl_module.prototype_labels
else:
proto_labels = COLOR_UNLABELED
# Calculate Voronoi Diagram
overall_min, overall_max = self.current_span(proto_values)
xx, yy, y_pred = self.get_voronoi_diagram(
overall_min,
overall_max,
pl_module,
)
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax() ax = self.setup_ax()
ax.contourf( self.plot_protos(ax, protos, plabels)
xx.cpu(), if x_train is not None:
yy.cpu(), self.plot_data(ax, x_train, y_train)
y_pred.cpu(), mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
cmap=self.cmap, self.border, self.resolution)
alpha=0.35, else:
) mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components
self.plot_data(ax, self.x_train, self.y_train) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
self.plot_protos(ax, proto_values, proto_labels) y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
self.log_and_display(trainer, pl_module) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisSiameseGLVQ2D(Vis2DAbstract): class VisSiameseGLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs): def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.map_protos = map_protos self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module): def visualize(self, pl_module):
if not self.show_on_current_epoch(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
@@ -218,18 +187,14 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
y_pred = y_pred.cpu().reshape(xx.shape) y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisGMLVQ2D(Vis2DAbstract): class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs): def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ev_proj = ev_proj self.ev_proj = ev_proj
def on_epoch_end(self, trainer, pl_module): def visualize(self, pl_module):
if not self.show_on_current_epoch(trainer):
return True
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
@@ -251,14 +216,28 @@ class VisGMLVQ2D(Vis2DAbstract):
if self.show_protos: if self.show_protos:
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.components_layer._components
y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisNG2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.show_on_current_epoch(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy() cmat = pl_module.topology_layer.cmat.cpu().numpy()
@@ -277,10 +256,27 @@ class VisNG2D(Vis2DAbstract):
"k-", "k-",
) )
self.log_and_display(trainer, pl_module)
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
class VisImgComp(Vis2DAbstract): class VisImgComp(Vis2DAbstract):
def __init__(self, def __init__(self,
*args, *args,
random_data=0, random_data=0,
@@ -333,14 +329,9 @@ class VisImgComp(Vis2DAbstract):
dataformats=self.dataformats, dataformats=self.dataformats,
) )
def on_epoch_end(self, trainer, pl_module): def visualize(self, pl_module):
if not self.show_on_current_epoch(trainer):
return True
if self.show: if self.show:
components = pl_module.components components = pl_module.components
grid = torchvision.utils.make_grid(components, grid = torchvision.utils.make_grid(components,
nrow=self.num_columns) nrow=self.num_columns)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap) plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)

View File

@@ -1,8 +1,23 @@
[isort]
profile = hug
src_paths = isort, test
[yapf] [yapf]
based_on_style = pep8 based_on_style = pep8
spaces_before_comment = 2 spaces_before_comment = 2
split_before_logical_operator = true split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79

View File

@@ -18,12 +18,12 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models" PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git" DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md") as fh: with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"prototorch>=0.7.0", "prototorch>=0.7.3",
"pytorch_lightning>=1.3.5", "pytorch_lightning>=1.6.0",
"torchmetrics", "torchmetrics",
] ]
CLI = [ CLI = [
@@ -54,7 +54,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name=safe_name("prototorch_" + PLUGIN_NAME), name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.3.0", version="0.5.0",
description="Pre-packaged prototype-based " description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.", "machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description, long_description=long_description,
@@ -64,7 +64,7 @@ setup(
url=PROJECT_URL, url=PROJECT_URL,
download_url=DOWNLOAD_URL, download_url=DOWNLOAD_URL,
license="MIT", license="MIT",
python_requires=">=3.6", python_requires=">=3.7",
install_requires=INSTALL_REQUIRES, install_requires=INSTALL_REQUIRES,
extras_require={ extras_require={
"dev": DEV, "dev": DEV,
@@ -80,10 +80,10 @@ setup(
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Natural Language :: English", "Natural Language :: English",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.6",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",

View File

@@ -1,14 +0,0 @@
"""prototorch.models test suite."""
import unittest
class TestDummy(unittest.TestCase):
def setUp(self):
pass
def test_dummy(self):
pass
def tearDown(self):
pass

195
tests/test_models.py Normal file
View File

@@ -0,0 +1,195 @@
"""prototorch.models test suite."""
import prototorch as pt
import pytest
import torch
def test_glvq_model_build():
model = pt.models.GLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq1_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq21_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gmlvq_model_build():
model = pt.models.GMLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_grlvq_model_build():
model = pt.models.GRLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gtlvq_model_build():
model = pt.models.GTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lgmlvq_model_build():
model = pt.models.LGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_image_glvq_model_build():
model = pt.models.ImageGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gtlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build():
model = pt.models.LVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lvq21_model_build():
model = pt.models.LVQ21(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_median_lvq_model_build():
model = pt.models.MedianLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_celvq_model_build():
model = pt.models.CELVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_rslvq_model_build():
model = pt.models.RSLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_slvq_model_build():
model = pt.models.SLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_kohonen_som_model_build():
model = pt.models.KohonenSOM(
{"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_neural_gas_model_build():
model = pt.models.NeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)