Compare commits
30 Commits
feature/ux
...
v0.5.0
Author | SHA1 | Date | |
---|---|---|---|
|
d6629c8792 | ||
|
ef65bd3789 | ||
|
d096eba2c9 | ||
|
dd34c57e2e | ||
|
5911f4dd90 | ||
|
dbfe315f4f | ||
|
9c90c902dc | ||
|
7d3f59e54b | ||
|
9da47b1dba | ||
|
41f0e77fc9 | ||
|
fab786a07e | ||
|
40bd7ed380 | ||
|
4941c2b89d | ||
|
ce14dec7e9 | ||
|
b31c8cc707 | ||
|
e21e6c7e02 | ||
|
dd696ea1e0 | ||
|
15e7232747 | ||
|
197b728c63 | ||
|
98892afee0 | ||
|
d5855dbe97 | ||
|
75a39f5b03 | ||
|
1a0e697b27 | ||
|
1a17193b35 | ||
|
aaa3c51e0a | ||
|
62c5974a85 | ||
|
1d26226a2f | ||
|
4232d0ed2a | ||
|
a9edf06507 | ||
|
d3bb430104 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.0
|
||||
current_version = 0.5.0
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
|
15
.codacy.yml
15
.codacy.yml
@@ -1,15 +0,0 @@
|
||||
# To validate the contents of your configuration file
|
||||
# run the following command in the folder where the configuration file is located:
|
||||
# codacy-analysis-cli validate-configuration --directory `pwd`
|
||||
# To analyse, run:
|
||||
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
||||
---
|
||||
engines:
|
||||
pylintpython3:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
remark-lint:
|
||||
exclude_paths:
|
||||
- config/engines.yml
|
||||
exclude_paths:
|
||||
- 'tests/**'
|
@@ -1,2 +0,0 @@
|
||||
comment:
|
||||
require_changes: yes
|
25
.github/workflows/examples.yml
vendored
Normal file
25
.github/workflows/examples.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: examples
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'examples/**.py'
|
||||
jobs:
|
||||
cpu:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- name: Run examples
|
||||
run: |
|
||||
./tests/test_examples.sh examples/
|
75
.github/workflows/pythonapp.yml
vendored
Normal file
75
.github/workflows/pythonapp.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
style:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- uses: pre-commit/action@v2.0.3
|
||||
compatibility:
|
||||
needs: style
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
exclude:
|
||||
- os: windows-latest
|
||||
python-version: "3.7"
|
||||
- os: windows-latest
|
||||
python-version: "3.8"
|
||||
- os: windows-latest
|
||||
python-version: "3.9"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest
|
||||
publish_pypi:
|
||||
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
|
||||
needs: compatibility
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
pip install wheel
|
||||
- name: Build package
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish a Python distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
@@ -3,7 +3,7 @@
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
rev: v4.1.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
@@ -18,19 +18,19 @@ repos:
|
||||
- id: autoflake
|
||||
|
||||
- repo: http://github.com/PyCQA/isort
|
||||
rev: 5.9.3
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.910-1
|
||||
rev: v0.931
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: prototorch
|
||||
additional_dependencies: [types-pkg_resources]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.31.0
|
||||
rev: v0.32.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
|
||||
@@ -42,10 +42,9 @@ repos:
|
||||
- id: python-check-blanket-noqa
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.29.0
|
||||
rev: v2.31.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py36-plus]
|
||||
|
||||
- repo: https://github.com/si-cim/gitlint
|
||||
rev: v0.15.2-unofficial
|
||||
|
44
.travis.yml
44
.travis.yml
@@ -1,44 +0,0 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python:
|
||||
- 3.9
|
||||
- 3.8
|
||||
- 3.7
|
||||
- 3.6
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
- "./tests/artifacts"
|
||||
- "$HOME/datasets"
|
||||
install:
|
||||
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
|
||||
- pip install .[all] --progress-bar off
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
- ./tests/test_examples.sh examples/
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
jobs:
|
||||
include:
|
||||
- stage: build
|
||||
python: 3.9
|
||||
script: echo "Starting Pypi build"
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
distributions: "sdist bdist_wheel"
|
||||
password:
|
||||
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
@@ -1,6 +1,5 @@
|
||||
# ProtoTorch Models
|
||||
|
||||
[](https://travis-ci.com/github/si-cim/prototorch_models)
|
||||
[](https://github.com/si-cim/prototorch_models/releases)
|
||||
[](https://pypi.org/project/prototorch_models/)
|
||||
[](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
|
||||
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.3.0"
|
||||
release = "0.5.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
|
@@ -2,223 +2,252 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7ac5eff0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# A short tutorial for the `prototorch.models` plugin"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beb83780",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Introduction"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43b74278",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
|
||||
"\n",
|
||||
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
||||
"\n",
|
||||
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4e5d1fad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basics"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1244b66b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dcb88e8a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import prototorch as pt\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1adbe2f8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Building Models"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "96663ab1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "819ba756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.GLVQ(\n",
|
||||
" hparams=dict(distribution=[1, 1, 1]),\n",
|
||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b37e97c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(model)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2c86903",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
|
||||
"\n",
|
||||
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45806052",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Data"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9d62c4c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "504df02c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = pt.datasets.Iris(dims=[0, 2])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3b8e7756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(train_ds)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bce43afa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds.data.shape, train_ds.targets.shape"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26a83328",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "67b80fbe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1185f31",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b5a8963",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_batch, y_batch = next(iter(train_loader))\n",
|
||||
"print(f\"{x_batch=}, {y_batch=}\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dd492ee2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5176b055",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Training"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "46a7a506",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "279e75b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e496b492",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.fit(model, train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "497fbff6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### From data to a trained model - a very minimal example"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ab069c5d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
|
||||
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
|
||||
@@ -230,49 +259,239 @@
|
||||
"\n",
|
||||
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
|
||||
"trainer.fit(model, train_loader)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "30c71a93",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced"
|
||||
],
|
||||
"metadata": {}
|
||||
"### Saving/Loading trained models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f74ed2c1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
||||
],
|
||||
"metadata": {}
|
||||
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3156658d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
|
||||
"trainer.save_checkpoint(ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1c34055",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bbbb08e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Visualizing decision boundaries in 2D"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "53ca52dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8373531f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Saving/Loading trained weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "937bc458",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f2035af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
|
||||
"torch.save(model.state_dict(), ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1206021a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.GLVQ(\n",
|
||||
" dict(distribution=(3, 2)),\n",
|
||||
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f2a4beb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "528d2fc2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch.load(ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec817e6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a208eab7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f8de748f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "53a64063",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Warm-start a model with prototypes learned from another model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3177c277",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
|
||||
"model = pt.models.SiameseGMLVQ(\n",
|
||||
" dict(input_dim=2,\n",
|
||||
" latent_dim=2,\n",
|
||||
" distribution=(3, 2),\n",
|
||||
" proto_lr=0.0001,\n",
|
||||
" bb_lr=0.0001),\n",
|
||||
" optimizer=torch.optim.Adam,\n",
|
||||
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
|
||||
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
|
||||
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8baee9a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc203088",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f6a33a5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initializing prototypes with a subset of a dataset (along with transformations)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "946ce341",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import prototorch as pt\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch\n",
|
||||
"from torchvision import transforms\n",
|
||||
"from torchvision.datasets import MNIST"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision.utils import make_grid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "510d9bd4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import pyplot as plt"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea7c1228",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_ds = MNIST(\n",
|
||||
" \"~/datasets\",\n",
|
||||
@@ -284,59 +503,87 @@
|
||||
" transforms.ToTensor(),\n",
|
||||
" ]),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b9eaf5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"s = int(0.05 * len(train_ds))\n",
|
||||
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8c32c9f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"init_ds"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "68a9a8b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.ImageGLVQ(\n",
|
||||
" dict(distribution=(10, 5)),\n",
|
||||
" dict(distribution=(10, 1)),\n",
|
||||
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=10))"
|
||||
],
|
||||
"id": "6f23df86",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
"source": [
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=5))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c23c7b2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "30780927",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protos, plabels = pt.components.LabeledComponents(\n",
|
||||
" distribution=(10, 5),\n",
|
||||
" components_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
|
||||
")()\n",
|
||||
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4fa69f92",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## FAQs"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fa20f9ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
|
||||
"\n",
|
||||
@@ -351,11 +598,12 @@
|
||||
"```python\n",
|
||||
">>> model.prototype_labels\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ba8215bf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### How do I make inferences/predictions/recall with my trained model?\n",
|
||||
"\n",
|
||||
@@ -370,13 +618,12 @@
|
||||
"```python\n",
|
||||
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -390,7 +637,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.4"
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@@ -38,12 +38,10 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.Visualize2DVoronoiCallback(
|
||||
data=train_ds,
|
||||
title="CBC Iris Example",
|
||||
resolution=100,
|
||||
axis_off=True,
|
||||
)
|
||||
vis = pt.models.VisCBC2D(data=train_ds,
|
||||
title="CBC Iris Example",
|
||||
resolution=100,
|
||||
axis_off=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
|
@@ -3,7 +3,6 @@
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import prototorch.models.clcc
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
@@ -30,7 +29,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = prototorch.models.GLVQ(
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
@@ -42,13 +41,7 @@ if __name__ == "__main__":
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.Visualize2DVoronoiCallback(
|
||||
data=train_ds,
|
||||
resolution=200,
|
||||
title="Example: GLVQ on Iris",
|
||||
x_label="sepal length",
|
||||
y_label="petal length",
|
||||
)
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
@@ -60,3 +53,13 @@ if __name__ == "__main__":
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./glvq_iris.ckpt")
|
||||
|
||||
# Load saved model
|
||||
new_model = pt.models.GLVQ.load_from_checkpoint(
|
||||
checkpoint_path="./glvq_iris.ckpt",
|
||||
strict=False,
|
||||
)
|
||||
print(new_model)
|
||||
|
@@ -1,4 +1,4 @@
|
||||
"""GLVQ example using the spiral dataset."""
|
||||
"""GMLVQ example using the spiral dataset."""
|
||||
|
||||
import argparse
|
||||
|
104
examples/gtlvq_mnist.py
Normal file
104
examples/gtlvq_mnist.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""GTLVQ example using the MNIST dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=True,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
test_ds = MNIST(
|
||||
"~/datasets",
|
||||
train=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
]),
|
||||
)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||
num_workers=0,
|
||||
batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
num_classes = 10
|
||||
prototypes_per_class = 1
|
||||
hparams = dict(
|
||||
input_dim=28 * 28,
|
||||
latent_dim=28,
|
||||
distribution=(num_classes, prototypes_per_class),
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.ImageGTLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
#Use one batch of data for subspace initiator.
|
||||
omega_initializer=pt.initializers.PCALinearTransformInitializer(
|
||||
next(iter(train_loader))[0].reshape(256, 28 * 28)))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisImgComp(
|
||||
data=train_ds,
|
||||
num_columns=10,
|
||||
show=False,
|
||||
tensorboard=True,
|
||||
random_data=100,
|
||||
add_embedding=True,
|
||||
embedding_data=200,
|
||||
flatten_data=False,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01,
|
||||
idle_epochs=1,
|
||||
prune_quota_per_epoch=10,
|
||||
frequency=1,
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=15,
|
||||
mode="min",
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
# using GPUs here is strongly recommended!
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
# es,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
weights_summary=None,
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
63
examples/gtlvq_moons.py
Normal file
63
examples/gtlvq_moons.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Localized-GTLVQ example using the Moons dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=256,
|
||||
shuffle=True)
|
||||
|
||||
# Hyperparameters
|
||||
# Latent_dim should be lower than input dim.
|
||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GTLVQ(
|
||||
hparams, prototypes_initializer=pt.initializers.SMCI(train_ds))
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_acc",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="max",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -10,6 +10,7 @@ from prototorch.utils.colors import hex_to_rgb
|
||||
|
||||
|
||||
class Vis2DColorSOM(pl.Callback):
|
||||
|
||||
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
|
||||
super().__init__()
|
||||
self.title = title
|
||||
|
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
|
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
|
73
examples/siamese_gtlvq_iris.py
Normal file
73
examples/siamese_gtlvq_iris.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.latent_size = latent_size
|
||||
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
||||
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.activation(self.dense1(x))
|
||||
out = self.activation(self.dense2(x))
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
input_dim=2,
|
||||
latent_dim=1)
|
||||
|
||||
# Initialize the backbone
|
||||
backbone = Backbone(latent_size=hparams["input_dim"])
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGTLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
backbone=backbone,
|
||||
both_path_gradients=False,
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -8,17 +8,32 @@ from .glvq import (
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
GTLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
ImageGTLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
SiameseGTLVQ,
|
||||
)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||
from .lvq import (
|
||||
LVQ1,
|
||||
LVQ21,
|
||||
MedianLVQ,
|
||||
)
|
||||
from .probabilistic import (
|
||||
CELVQ,
|
||||
RSLVQ,
|
||||
SLVQ,
|
||||
)
|
||||
from .unsupervised import (
|
||||
GrowingNeuralGas,
|
||||
KohonenSOM,
|
||||
NeuralGas,
|
||||
)
|
||||
from .vis import *
|
||||
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "0.5.0"
|
||||
|
@@ -3,16 +3,18 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.core.competitions import WTAC
|
||||
from prototorch.core.components import Components, LabeledComponents
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.core.initializers import LabelsInitializer
|
||||
from prototorch.core.pooling import stratified_min_pooling
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from ..core.competitions import WTAC
|
||||
from ..core.components import Components, LabeledComponents
|
||||
from ..core.distances import euclidean_distance
|
||||
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
|
||||
from ..core.pooling import stratified_min_pooling
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
|
||||
|
||||
class ProtoTorchBolt(pl.LightningModule):
|
||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
@@ -41,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
||||
return optimizer
|
||||
|
||||
def reconfigure_optimizers(self):
|
||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||
self.trainer.strategy.setup_optimizers(self.trainer)
|
||||
|
||||
def __repr__(self):
|
||||
surep = super().__repr__()
|
||||
@@ -51,6 +53,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
||||
|
||||
|
||||
class PrototypeModel(ProtoTorchBolt):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -72,14 +75,17 @@ class PrototypeModel(ProtoTorchBolt):
|
||||
|
||||
def add_prototypes(self, *args, **kwargs):
|
||||
self.proto_layer.add_components(*args, **kwargs)
|
||||
self.hparams.distribution = self.proto_layer.distribution
|
||||
self.reconfigure_optimizers()
|
||||
|
||||
def remove_prototypes(self, indices):
|
||||
self.proto_layer.remove_components(indices)
|
||||
self.hparams.distribution = self.proto_layer.distribution
|
||||
self.reconfigure_optimizers()
|
||||
|
||||
|
||||
class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -102,19 +108,33 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
|
||||
class SupervisedPrototypeModel(PrototypeModel):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
|
||||
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Layers
|
||||
distribution = hparams.get("distribution", None)
|
||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||
labels_initializer = kwargs.get("labels_initializer",
|
||||
LabelsInitializer())
|
||||
if prototypes_initializer is not None:
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=self.hparams.distribution,
|
||||
components_initializer=prototypes_initializer,
|
||||
labels_initializer=labels_initializer,
|
||||
)
|
||||
if not skip_proto_layer:
|
||||
# when subclasses do not need a customized prototype layer
|
||||
if prototypes_initializer is not None:
|
||||
# when building a new model
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=distribution,
|
||||
components_initializer=prototypes_initializer,
|
||||
labels_initializer=labels_initializer,
|
||||
)
|
||||
proto_shape = self.proto_layer.components.shape[1:]
|
||||
self.hparams.initialized_proto_shape = proto_shape
|
||||
else:
|
||||
# when restoring a checkpointed model
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=distribution,
|
||||
components_initializer=ZerosCompInitializer(
|
||||
self.hparams.initialized_proto_shape),
|
||||
)
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
@property
|
||||
@@ -134,7 +154,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
distances = self.compute_distances(x)
|
||||
_, plabels = self.proto_layer()
|
||||
winning = stratified_min_pooling(distances, plabels)
|
||||
y_pred = torch.nn.functional.softmin(winning)
|
||||
y_pred = torch.nn.functional.softmin(winning, dim=1)
|
||||
return y_pred
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
@@ -168,3 +188,33 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
|
||||
self.log("test_acc", accuracy)
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||
from torchvision.utils import make_grid
|
||||
grid = make_grid(self.components, nrow=num_columns)
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
||||
|
@@ -4,13 +4,14 @@ import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from prototorch.core.components import Components
|
||||
from prototorch.core.initializers import LiteralCompInitializer
|
||||
|
||||
from ..core.components import Components
|
||||
from ..core.initializers import LiteralCompInitializer
|
||||
from .extras import ConnectionTopology
|
||||
|
||||
|
||||
class PruneLoserPrototypes(pl.Callback):
|
||||
|
||||
def __init__(self,
|
||||
threshold=0.01,
|
||||
idle_epochs=10,
|
||||
@@ -67,6 +68,7 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
|
||||
|
||||
class PrototypeConvergence(pl.Callback):
|
||||
|
||||
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
|
||||
self.min_delta = min_delta
|
||||
self.idle_epochs = idle_epochs # epochs to wait
|
||||
@@ -89,6 +91,7 @@ class GNGCallback(pl.Callback):
|
||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=0.1, freq=10):
|
||||
self.reduction = reduction
|
||||
self.freq = freq
|
||||
@@ -134,4 +137,4 @@ class GNGCallback(pl.Callback):
|
||||
pl_module.errors[
|
||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||
|
||||
trainer.accelerator.setup_optimizers(trainer)
|
||||
trainer.strategy.setup_optimizers(trainer)
|
||||
|
@@ -1,20 +1,21 @@
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.core.competitions import CBCC
|
||||
from prototorch.core.components import ReasoningComponents
|
||||
from prototorch.core.initializers import RandomReasoningsInitializer
|
||||
from prototorch.core.losses import MarginLoss
|
||||
from prototorch.core.similarities import euclidean_similarity
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from ..core.competitions import CBCC
|
||||
from ..core.components import ReasoningComponents
|
||||
from ..core.initializers import RandomReasoningsInitializer
|
||||
from ..core.losses import MarginLoss
|
||||
from ..core.similarities import euclidean_similarity
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import ImagePrototypesMixin
|
||||
from .glvq import SiameseGLVQ
|
||||
from .mixin import ImagePrototypesMixin
|
||||
|
||||
|
||||
class CBC(SiameseGLVQ):
|
||||
"""Classification-By-Components."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
||||
|
||||
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||
components_initializer = kwargs.get("components_initializer", None)
|
||||
|
@@ -1,86 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from prototorch.core.competitions import WTAC
|
||||
from prototorch.core.components import LabeledComponents
|
||||
from prototorch.core.distances import euclidean_distance
|
||||
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
|
||||
from prototorch.core.losses import GLVQLoss
|
||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
|
||||
@dataclass
|
||||
class GLVQhparams:
|
||||
distribution: dict
|
||||
component_initializer: AbstractComponentsInitializer
|
||||
distance_fn: Callable = euclidean_distance
|
||||
lr: float = 0.01
|
||||
margin: float = 0.0
|
||||
# TODO: make nicer
|
||||
transfer_fn: str = "identity"
|
||||
transfer_beta: float = 10.0
|
||||
optimizer: torch.optim.Optimizer = torch.optim.Adam
|
||||
|
||||
|
||||
class GLVQ(CLCCScheme):
|
||||
def __init__(self, hparams: GLVQhparams) -> None:
|
||||
super().__init__(hparams)
|
||||
self.lr = hparams.lr
|
||||
self.optimizer = hparams.optimizer
|
||||
|
||||
# Initializers
|
||||
def init_components(self, hparams):
|
||||
# initialize Component Layer
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=hparams.component_initializer,
|
||||
labels_initializer=LabelsInitializer(),
|
||||
)
|
||||
|
||||
def init_comparison(self, hparams):
|
||||
# initialize Distance Layer
|
||||
self.comparison_layer = LambdaLayer(hparams.distance_fn)
|
||||
|
||||
def init_inference(self, hparams):
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
def init_loss(self, hparams):
|
||||
self.loss_layer = GLVQLoss(
|
||||
margin=hparams.margin,
|
||||
transfer_fn=hparams.transfer_fn,
|
||||
beta=hparams.transfer_beta,
|
||||
)
|
||||
|
||||
# Steps
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
batch_tensor, _ = batch
|
||||
|
||||
comp_tensor = comp_tensor.unsqueeze(1)
|
||||
|
||||
distances = self.comparison_layer(batch_tensor, comp_tensor)
|
||||
|
||||
return distances
|
||||
|
||||
def inference(self, comparisonmeasures, components):
|
||||
comp_labels = components[1]
|
||||
return self.competition_layer(comparisonmeasures, comp_labels)
|
||||
|
||||
def loss(self, comparisonmeasures, batch, components):
|
||||
target = batch[1]
|
||||
comp_labels = components[1]
|
||||
return self.loss_layer(comparisonmeasures, target, comp_labels)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer(self.parameters(), lr=self.lr)
|
||||
|
||||
# Properties
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.components_layer.components.detach().cpu()
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.components_layer.labels.detach().cpu()
|
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
CLCC Scheme
|
||||
|
||||
CLCC is a LVQ scheme containing 4 steps
|
||||
- Components
|
||||
- Latent Space
|
||||
- Comparison
|
||||
- Competition
|
||||
|
||||
"""
|
||||
from typing import Dict, Set, Type
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
|
||||
class CLCCScheme(pl.LightningModule):
|
||||
registered_metrics: Dict[Type[torchmetrics.Metric],
|
||||
torchmetrics.Metric] = {}
|
||||
registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {}
|
||||
|
||||
def __init__(self, hparams) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Common Steps
|
||||
self.init_components(hparams)
|
||||
self.init_latent(hparams)
|
||||
self.init_comparison(hparams)
|
||||
self.init_competition(hparams)
|
||||
|
||||
# Train Steps
|
||||
self.init_loss(hparams)
|
||||
|
||||
# Inference Steps
|
||||
self.init_inference(hparams)
|
||||
|
||||
# Initialize Model Metrics
|
||||
self.init_model_metrics()
|
||||
|
||||
# internal API, called by models and callbacks
|
||||
def register_torchmetric(self, name: str, metric: torchmetrics.Metric):
|
||||
if metric not in self.registered_metrics:
|
||||
self.registered_metrics[metric] = metric()
|
||||
self.registered_metric_names[metric] = {name}
|
||||
else:
|
||||
self.registered_metric_names[metric].add(name)
|
||||
|
||||
# external API
|
||||
def get_competion(self, batch, components):
|
||||
latent_batch, latent_components = self.latent(batch, components)
|
||||
# TODO: => Latent Hook
|
||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||
# TODO: => Comparison Hook
|
||||
return comparison_tensor
|
||||
|
||||
def forward(self, batch):
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch = (batch, None)
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competion(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.inference(comparison_tensor, components)
|
||||
|
||||
def predict(self, batch):
|
||||
"""
|
||||
Alias for forward
|
||||
"""
|
||||
return self.forward(batch)
|
||||
|
||||
def loss_forward(self, batch):
|
||||
# TODO: manage different datatypes?
|
||||
components = self.components_layer()
|
||||
# TODO: => Component Hook
|
||||
comparison_tensor = self.get_competion(batch, components)
|
||||
# TODO: => Competition Hook
|
||||
return self.loss(comparison_tensor, batch, components)
|
||||
|
||||
# Empty Initialization
|
||||
# TODO: Type hints
|
||||
# TODO: Docs
|
||||
def init_components(self, hparams):
|
||||
...
|
||||
|
||||
def init_latent(self, hparams):
|
||||
...
|
||||
|
||||
def init_comparison(self, hparams):
|
||||
...
|
||||
|
||||
def init_competition(self, hparams):
|
||||
...
|
||||
|
||||
def init_loss(self, hparams):
|
||||
...
|
||||
|
||||
def init_inference(self, hparams):
|
||||
...
|
||||
|
||||
def init_model_metrics(self):
|
||||
self.register_torchmetric('train_accuracy', torchmetrics.Accuracy)
|
||||
|
||||
# Empty Steps
|
||||
# TODO: Type hints
|
||||
def components(self):
|
||||
"""
|
||||
This step has no input.
|
||||
|
||||
It returns the components.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The components step has no reasonable default.")
|
||||
|
||||
def latent(self, batch, components):
|
||||
"""
|
||||
The latent step receives the data batch and the components.
|
||||
It can transform both by an arbitrary function.
|
||||
|
||||
It returns the transformed batch and components, each of the same length as the original input.
|
||||
"""
|
||||
return batch, components
|
||||
|
||||
def comparison(self, batch, components):
|
||||
"""
|
||||
Takes a batch of size N and the componentsset of size M.
|
||||
|
||||
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The comparison step has no reasonable default.")
|
||||
|
||||
def competition(self, comparisonmeasures, components):
|
||||
"""
|
||||
Takes the tensor of comparison measures.
|
||||
|
||||
Assigns a competition vector to each class.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The competition step has no reasonable default.")
|
||||
|
||||
def loss(self, comparisonmeasures, batch, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Calculates a single loss value
|
||||
"""
|
||||
raise NotImplementedError("The loss step has no reasonable default.")
|
||||
|
||||
def inference(self, comparisonmeasures, components):
|
||||
"""
|
||||
Takes the tensor of competition measures.
|
||||
|
||||
Returns the inferred vector.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"The inference step has no reasonable default.")
|
||||
|
||||
def update_metrics_step(self, batch):
|
||||
x, y = batch
|
||||
preds = self(x)
|
||||
|
||||
for metric in self.registered_metrics:
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
value = instance(y, preds)
|
||||
|
||||
for name in self.registered_metric_names[metric]:
|
||||
self.log(name, value)
|
||||
|
||||
def update_metrics_epoch(self):
|
||||
for metric in self.registered_metrics:
|
||||
instance = self.registered_metrics[metric].to(self.device)
|
||||
value = instance.compute()
|
||||
|
||||
for name in self.registered_metric_names[metric]:
|
||||
self.log(name, value)
|
||||
|
||||
# Lightning Hooks
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
self.update_metrics_step(batch)
|
||||
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def train_epoch_end(self, outs) -> None:
|
||||
self.update_metrics_epoch()
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self.loss_forward(batch)
|
@@ -1,76 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
|
||||
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
|
||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
||||
from prototorch.models.vis import Visualize2DVoronoiCallback
|
||||
|
||||
# NEW STUFF
|
||||
# ##############################################################################
|
||||
|
||||
|
||||
# TODO: Metrics
|
||||
class MetricsTestCallback(pl.Callback):
|
||||
metric_name = "test_cb_acc"
|
||||
|
||||
def setup(self,
|
||||
trainer: pl.Trainer,
|
||||
pl_module: CLCCScheme,
|
||||
stage: Optional[str] = None) -> None:
|
||||
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
|
||||
|
||||
def on_epoch_end(self, trainer: pl.Trainer,
|
||||
pl_module: pl.LightningModule) -> None:
|
||||
metric = trainer.logged_metrics[self.metric_name]
|
||||
if metric > 0.95:
|
||||
trainer.should_stop = True
|
||||
|
||||
|
||||
# TODO: Pruning
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=64,
|
||||
num_workers=8)
|
||||
|
||||
components_initializer = SMCI(train_ds)
|
||||
|
||||
hparams = GLVQhparams(
|
||||
distribution=dict(
|
||||
num_classes=3,
|
||||
per_class=2,
|
||||
),
|
||||
component_initializer=components_initializer,
|
||||
)
|
||||
model = GLVQ(hparams)
|
||||
|
||||
print(model)
|
||||
# Callbacks
|
||||
vis = Visualize2DVoronoiCallback(
|
||||
data=train_ds,
|
||||
resolution=500,
|
||||
)
|
||||
metrics = MetricsTestCallback()
|
||||
|
||||
# Train
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
#vis,
|
||||
metrics,
|
||||
],
|
||||
gpus=1,
|
||||
max_epochs=100,
|
||||
weights_summary=None,
|
||||
log_every_n_steps=1,
|
||||
)
|
||||
trainer.fit(model, train_loader)
|
@@ -5,7 +5,8 @@ Modules not yet available in prototorch go here temporarily.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from prototorch.core.similarities import gaussian
|
||||
|
||||
from ..core.similarities import gaussian
|
||||
|
||||
|
||||
def rank_scaled_gaussian(distances, lambd):
|
||||
@@ -14,7 +15,46 @@ def rank_scaled_gaussian(distances, lambd):
|
||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||
|
||||
|
||||
def orthogonalization(tensors):
|
||||
"""Orthogonalization via polar decomposition """
|
||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||
u_shape = tuple(list(u.shape))
|
||||
v_shape = tuple(list(v.shape))
|
||||
|
||||
# reshape to (num x N x M)
|
||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||
|
||||
out = u @ v.permute([0, 2, 1])
|
||||
|
||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def ltangent_distance(x, y, omegas):
|
||||
r"""Localized Tangent distance.
|
||||
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
|
||||
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||
omegas, omegas.permute([0, 2, 1]))
|
||||
projected_x = x @ p
|
||||
projected_y = torch.diagonal(y @ p).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
batchwise_difference = expanded_y - projected_x
|
||||
differences_squared = batchwise_difference**2
|
||||
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
|
||||
distances = distances.permute(1, 0)
|
||||
return distances
|
||||
|
||||
|
||||
class GaussianPrior(torch.nn.Module):
|
||||
|
||||
def __init__(self, variance):
|
||||
super().__init__()
|
||||
self.variance = variance
|
||||
@@ -24,6 +64,7 @@ class GaussianPrior(torch.nn.Module):
|
||||
|
||||
|
||||
class RankScaledGaussianPrior(torch.nn.Module):
|
||||
|
||||
def __init__(self, lambd):
|
||||
super().__init__()
|
||||
self.lambd = lambd
|
||||
@@ -33,6 +74,7 @@ class RankScaledGaussianPrior(torch.nn.Module):
|
||||
|
||||
|
||||
class ConnectionTopology(torch.nn.Module):
|
||||
|
||||
def __init__(self, agelimit, num_prototypes):
|
||||
super().__init__()
|
||||
self.agelimit = agelimit
|
||||
|
@@ -1,20 +1,29 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
from prototorch.core.competitions import wtac
|
||||
from prototorch.core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||
from prototorch.core.initializers import EyeTransformInitializer
|
||||
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||
from prototorch.core.transforms import LinearTransform
|
||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .abstract import SupervisedPrototypeModel
|
||||
from .mixin import ImagePrototypesMixin
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from ..core.initializers import EyeLinearTransformInitializer
|
||||
from ..core.losses import (
|
||||
GLVQLoss,
|
||||
lvq1_loss,
|
||||
lvq21_loss,
|
||||
)
|
||||
from ..core.transforms import LinearTransform
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
from .extras import ltangent_distance, orthogonalization
|
||||
|
||||
|
||||
class GLVQ(SupervisedPrototypeModel):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -30,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
beta=self.hparams.transfer_beta,
|
||||
)
|
||||
|
||||
# def on_save_checkpoint(self, checkpoint):
|
||||
# if "prototype_win_ratios" in checkpoint["state_dict"]:
|
||||
# del checkpoint["state_dict"]["prototype_win_ratios"]
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
@@ -99,6 +112,7 @@ class SiameseGLVQ(GLVQ):
|
||||
transformation pipeline are only learned from the inputs.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hparams,
|
||||
backbone=torch.nn.Identity(),
|
||||
@@ -131,11 +145,15 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
|
||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||
|
||||
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
self.backbone.requires_grad_(True)
|
||||
self.backbone.requires_grad_(bb_grad)
|
||||
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
return distances
|
||||
|
||||
@@ -165,6 +183,7 @@ class LVQMLN(SiameseGLVQ):
|
||||
rather in the embedding space.
|
||||
|
||||
"""
|
||||
|
||||
def compute_distances(self, x):
|
||||
latent_protos, _ = self.proto_layer()
|
||||
latent_x = self.backbone(x)
|
||||
@@ -180,6 +199,7 @@ class GRLVQ(SiameseGLVQ):
|
||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -205,15 +225,16 @@ class SiameseGMLVQ(SiameseGLVQ):
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Override the backbone
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeTransformInitializer())
|
||||
EyeLinearTransformInitializer())
|
||||
self.backbone = LinearTransform(
|
||||
self.hparams.input_dim,
|
||||
self.hparams.output_dim,
|
||||
self.hparams.latent_dim,
|
||||
initializer=omega_initializer,
|
||||
)
|
||||
|
||||
@@ -235,13 +256,14 @@ class GMLVQ(GLVQ):
|
||||
function. This makes it easier to implement a localized variant.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeTransformInitializer())
|
||||
EyeLinearTransformInitializer())
|
||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
@@ -269,6 +291,7 @@ class GMLVQ(GLVQ):
|
||||
|
||||
class LGMLVQ(GMLVQ):
|
||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
@@ -283,8 +306,48 @@ class LGMLVQ(GMLVQ):
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
|
||||
class GTLVQ(LGMLVQ):
|
||||
"""Localized and Generalized Tangent Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
omega_initializer = kwargs.get("omega_initializer")
|
||||
|
||||
if omega_initializer is not None:
|
||||
subspace = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
omega = torch.repeat_interleave(subspace.unsqueeze(0),
|
||||
self.num_prototypes,
|
||||
dim=0)
|
||||
else:
|
||||
omega = torch.rand(
|
||||
self.num_prototypes,
|
||||
self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Re-register `_omega` to override the one from the super class.
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
with torch.no_grad():
|
||||
self._omega.copy_(orthogonalization(self._omega))
|
||||
|
||||
|
||||
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
|
||||
"""Generalized Tangent Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class GLVQ1(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 1."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = LossLayer(lvq1_loss)
|
||||
@@ -293,6 +356,7 @@ class GLVQ1(GLVQ):
|
||||
|
||||
class GLVQ21(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 2.1."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = LossLayer(lvq21_loss)
|
||||
@@ -315,3 +379,18 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
||||
after updates.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
|
||||
"""GTLVQ for training on image data.
|
||||
|
||||
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
after updates.
|
||||
|
||||
"""
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
with torch.no_grad():
|
||||
self._omega.copy_(orthogonalization(self._omega))
|
||||
|
@@ -2,18 +2,21 @@
|
||||
|
||||
import warnings
|
||||
|
||||
from prototorch.core.competitions import KNNC
|
||||
from prototorch.core.components import LabeledComponents
|
||||
from prototorch.core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||
from prototorch.utils.utils import parse_data_arg
|
||||
|
||||
from ..core.competitions import KNNC
|
||||
from ..core.components import LabeledComponents
|
||||
from ..core.initializers import (
|
||||
LiteralCompInitializer,
|
||||
LiteralLabelsInitializer,
|
||||
)
|
||||
from ..utils.utils import parse_data_arg
|
||||
from .abstract import SupervisedPrototypeModel
|
||||
|
||||
|
||||
class KNN(SupervisedPrototypeModel):
|
||||
"""K-Nearest-Neighbors classification algorithm."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("k", 1)
|
||||
@@ -25,7 +28,7 @@ class KNN(SupervisedPrototypeModel):
|
||||
|
||||
# Layers
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=[],
|
||||
distribution=len(data) * [1],
|
||||
components_initializer=LiteralCompInitializer(data),
|
||||
labels_initializer=LiteralLabelsInitializer(targets))
|
||||
self.competition_layer = KNNC(k=self.hparams.k)
|
||||
|
@@ -1,15 +1,15 @@
|
||||
"""LVQ models that are optimized using non-gradient methods."""
|
||||
|
||||
from prototorch.core.losses import _get_dp_dm
|
||||
from prototorch.nn.activations import get_activation
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from ..core.losses import _get_dp_dm
|
||||
from ..nn.activations import get_activation
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import NonGradientMixin
|
||||
from .glvq import GLVQ
|
||||
from .mixin import NonGradientMixin
|
||||
|
||||
|
||||
class LVQ1(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 1."""
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos, plables = self.proto_layer()
|
||||
x, y = train_batch
|
||||
@@ -39,6 +39,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
||||
|
||||
class LVQ21(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 2.1."""
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos, plabels = self.proto_layer()
|
||||
|
||||
@@ -71,6 +72,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
||||
# TODO Avoid computing distances over and over
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, verbose=True, **kwargs):
|
||||
self.verbose = verbose
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
@@ -1,27 +0,0 @@
|
||||
class ProtoTorchMixin:
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
pass
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||
from torchvision.utils import make_grid
|
||||
grid = make_grid(self.components, nrow=num_columns)
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
@@ -1,16 +1,17 @@
|
||||
"""Probabilistic GLVQ methods"""
|
||||
|
||||
import torch
|
||||
from prototorch.core.losses import nllr_loss, rslvq_loss
|
||||
from prototorch.core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||
|
||||
from ..core.losses import nllr_loss, rslvq_loss
|
||||
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||
from .glvq import GLVQ, SiameseGMLVQ
|
||||
|
||||
|
||||
class CELVQ(GLVQ):
|
||||
"""Cross-Entropy Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -29,6 +30,7 @@ class CELVQ(GLVQ):
|
||||
|
||||
|
||||
class ProbabilisticLVQ(GLVQ):
|
||||
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -62,18 +64,30 @@ class ProbabilisticLVQ(GLVQ):
|
||||
|
||||
class SLVQ(ProbabilisticLVQ):
|
||||
"""Soft Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("variance", 1.0)
|
||||
variance = self.hparams.get("variance")
|
||||
|
||||
self.conditional_distribution = GaussianPrior(variance)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
"""Robust Soft Learning Vector Quantization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("variance", 1.0)
|
||||
variance = self.hparams.get("variance")
|
||||
|
||||
self.conditional_distribution = GaussianPrior(variance)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
@@ -81,10 +95,15 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
|
||||
TODO: Use Backbone LVQ instead
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conditional_distribution = RankScaledGaussianPrior(
|
||||
self.hparams.lambd)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("lambda", 1.0)
|
||||
lam = self.hparams.get("lambda", 1.0)
|
||||
|
||||
self.conditional_distribution = RankScaledGaussianPrior(lam)
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
# FIXME
|
||||
|
@@ -2,15 +2,14 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from prototorch.core.competitions import wtac
|
||||
from prototorch.core.distances import squared_euclidean_distance
|
||||
from prototorch.core.losses import NeuralGasEnergy
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from .abstract import UnsupervisedPrototypeModel
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import squared_euclidean_distance
|
||||
from ..core.losses import NeuralGasEnergy
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||
from .callbacks import GNGCallback
|
||||
from .extras import ConnectionTopology
|
||||
from .mixin import NonGradientMixin
|
||||
|
||||
|
||||
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
@@ -19,6 +18,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
TODO Allow non-2D grids
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
h, w = hparams.get("shape")
|
||||
# Ignore `num_prototypes`
|
||||
@@ -35,7 +35,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
|
||||
# Additional parameters
|
||||
x, y = torch.arange(h), torch.arange(w)
|
||||
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
|
||||
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
|
||||
self.register_buffer("_grid", grid)
|
||||
self._sigma = self.hparams.sigma
|
||||
self._lr = self.hparams.lr
|
||||
@@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
diff = x.unsqueeze(dim=1) - protos
|
||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||
updated_protos = protos + delta.sum(dim=0)
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
self.proto_layer.load_state_dict(
|
||||
{"_components": updated_protos},
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
self._sigma = self.hparams.sigma * np.exp(
|
||||
@@ -70,6 +72,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
|
||||
|
||||
class HeskesSOM(UnsupervisedPrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -79,6 +82,7 @@ class HeskesSOM(UnsupervisedPrototypeModel):
|
||||
|
||||
|
||||
class NeuralGas(UnsupervisedPrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -86,12 +90,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("agelimit", 10)
|
||||
self.hparams.setdefault("age_limit", 10)
|
||||
self.hparams.setdefault("lm", 1)
|
||||
|
||||
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
||||
self.topology_layer = ConnectionTopology(
|
||||
agelimit=self.hparams.agelimit,
|
||||
agelimit=self.hparams.age_limit,
|
||||
num_prototypes=self.hparams.num_prototypes,
|
||||
)
|
||||
|
||||
@@ -111,6 +115,7 @@ class NeuralGas(UnsupervisedPrototypeModel):
|
||||
|
||||
|
||||
class GrowingNeuralGas(NeuralGas):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -142,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
|
||||
|
||||
def configure_callbacks(self):
|
||||
return [
|
||||
GNGCallback(reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq)
|
||||
GNGCallback(
|
||||
reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq,
|
||||
)
|
||||
]
|
||||
|
@@ -5,19 +5,21 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.utils.utils import generate_mesh, mesh2d
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
COLOR_UNLABELED = 'w'
|
||||
from ..utils.colors import get_colors, get_legend_handles
|
||||
from ..utils.utils import mesh2d
|
||||
|
||||
|
||||
class Vis2DAbstract(pl.Callback):
|
||||
|
||||
def __init__(self,
|
||||
data,
|
||||
title=None,
|
||||
x_label=None,
|
||||
y_label=None,
|
||||
data=None,
|
||||
title="Prototype Visualization",
|
||||
cmap="viridis",
|
||||
xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2",
|
||||
legend_labels=None,
|
||||
border=0.1,
|
||||
resolution=100,
|
||||
flatten_data=True,
|
||||
@@ -30,26 +32,31 @@ class Vis2DAbstract(pl.Callback):
|
||||
block=False):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(data, Dataset):
|
||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||
elif isinstance(data, torch.utils.data.DataLoader):
|
||||
x = torch.tensor([])
|
||||
y = torch.tensor([])
|
||||
for x_b, y_b in data:
|
||||
x = torch.cat([x, x_b])
|
||||
y = torch.cat([y, y_b])
|
||||
if data:
|
||||
if isinstance(data, Dataset):
|
||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||
elif isinstance(data, torch.utils.data.DataLoader):
|
||||
x = torch.tensor([])
|
||||
y = torch.tensor([])
|
||||
for x_b, y_b in data:
|
||||
x = torch.cat([x, x_b])
|
||||
y = torch.cat([y, y_b])
|
||||
else:
|
||||
x, y = data
|
||||
|
||||
if flatten_data:
|
||||
x = x.reshape(len(x), -1)
|
||||
|
||||
self.x_train = x
|
||||
self.y_train = y
|
||||
else:
|
||||
x, y = data
|
||||
|
||||
if flatten_data:
|
||||
x = x.reshape(len(x), -1)
|
||||
|
||||
self.x_train = x
|
||||
self.y_train = y
|
||||
self.x_train = None
|
||||
self.y_train = None
|
||||
|
||||
self.title = title
|
||||
self.x_label = x_label
|
||||
self.y_label = y_label
|
||||
self.xlabel = xlabel
|
||||
self.ylabel = ylabel
|
||||
self.legend_labels = legend_labels
|
||||
self.fig = plt.figure(self.title)
|
||||
self.cmap = cmap
|
||||
self.border = border
|
||||
@@ -62,19 +69,18 @@ class Vis2DAbstract(pl.Callback):
|
||||
self.pause_time = pause_time
|
||||
self.block = block
|
||||
|
||||
def show_on_current_epoch(self, trainer):
|
||||
if self.show_last_only and trainer.current_epoch != trainer.max_epochs - 1:
|
||||
return False
|
||||
def precheck(self, trainer):
|
||||
if self.show_last_only:
|
||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup_ax(self):
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
if self.x_label:
|
||||
ax.set_xlabel(self.x_label)
|
||||
if self.x_label:
|
||||
ax.set_ylabel(self.y_label)
|
||||
ax.set_xlabel(self.xlabel)
|
||||
ax.set_ylabel(self.ylabel)
|
||||
if self.axis_off:
|
||||
ax.axis("off")
|
||||
return ax
|
||||
@@ -117,81 +123,44 @@ class Vis2DAbstract(pl.Callback):
|
||||
else:
|
||||
plt.show(block=self.block)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
self.visualize(pl_module)
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
plt.close()
|
||||
|
||||
|
||||
class Visualize2DVoronoiCallback(Vis2DAbstract):
|
||||
def __init__(self, data, **kwargs):
|
||||
super().__init__(data, **kwargs)
|
||||
|
||||
self.data_min = torch.min(self.x_train, axis=0).values
|
||||
self.data_max = torch.max(self.x_train, axis=0).values
|
||||
|
||||
def current_span(self, proto_values):
|
||||
proto_min = torch.min(proto_values, axis=0).values
|
||||
proto_max = torch.max(proto_values, axis=0).values
|
||||
|
||||
overall_min = torch.minimum(proto_min, self.data_min)
|
||||
overall_max = torch.maximum(proto_max, self.data_max)
|
||||
|
||||
return overall_min, overall_max
|
||||
|
||||
def get_voronoi_diagram(self, min, max, model):
|
||||
mesh_input, (xx, yy) = generate_mesh(
|
||||
min,
|
||||
max,
|
||||
border=self.border,
|
||||
resolution=self.resolution,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
y_pred = model.predict(mesh_input)
|
||||
return xx, yy, y_pred.reshape(xx.shape)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.show_on_current_epoch(trainer):
|
||||
return True
|
||||
|
||||
# Extract Prototypes
|
||||
proto_values = pl_module.prototypes
|
||||
if hasattr(pl_module, "prototype_labels"):
|
||||
proto_labels = pl_module.prototype_labels
|
||||
else:
|
||||
proto_labels = COLOR_UNLABELED
|
||||
|
||||
# Calculate Voronoi Diagram
|
||||
overall_min, overall_max = self.current_span(proto_values)
|
||||
xx, yy, y_pred = self.get_voronoi_diagram(
|
||||
overall_min,
|
||||
overall_max,
|
||||
pl_module,
|
||||
)
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax()
|
||||
ax.contourf(
|
||||
xx.cpu(),
|
||||
yy.cpu(),
|
||||
y_pred.cpu(),
|
||||
cmap=self.cmap,
|
||||
alpha=0.35,
|
||||
)
|
||||
|
||||
self.plot_data(ax, self.x_train, self.y_train)
|
||||
self.plot_protos(ax, proto_values, proto_labels)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
if x_train is not None:
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
||||
self.border, self.resolution)
|
||||
else:
|
||||
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||
_components = pl_module.proto_layer._components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
|
||||
class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
|
||||
def __init__(self, *args, map_protos=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.map_protos = map_protos
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.show_on_current_epoch(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
@@ -218,18 +187,14 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
|
||||
class VisGMLVQ2D(Vis2DAbstract):
|
||||
|
||||
def __init__(self, *args, ev_proj=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ev_proj = ev_proj
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.show_on_current_epoch(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
@@ -251,14 +216,28 @@ class VisGMLVQ2D(Vis2DAbstract):
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
class VisCBC2D(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.components
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||
_components = pl_module.components_layer._components
|
||||
y_pred = pl_module.predict(
|
||||
torch.Tensor(mesh_input).type_as(_components))
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
|
||||
class VisNG2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.show_on_current_epoch(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.prototypes
|
||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||
@@ -277,10 +256,27 @@ class VisNG2D(Vis2DAbstract):
|
||||
"k-",
|
||||
)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
class VisSpectralProtos(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
ax = self.setup_ax()
|
||||
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
|
||||
for p, pl in zip(protos, plabels):
|
||||
ax.plot(p, c=colors[int(pl)])
|
||||
if self.legend_labels:
|
||||
handles = get_legend_handles(
|
||||
colors,
|
||||
self.legend_labels,
|
||||
marker="lines",
|
||||
)
|
||||
ax.legend(handles=handles)
|
||||
|
||||
|
||||
class VisImgComp(Vis2DAbstract):
|
||||
|
||||
def __init__(self,
|
||||
*args,
|
||||
random_data=0,
|
||||
@@ -333,14 +329,9 @@ class VisImgComp(Vis2DAbstract):
|
||||
dataformats=self.dataformats,
|
||||
)
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.show_on_current_epoch(trainer):
|
||||
return True
|
||||
|
||||
def visualize(self, pl_module):
|
||||
if self.show:
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components,
|
||||
nrow=self.num_columns)
|
||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
23
setup.cfg
23
setup.cfg
@@ -1,8 +1,23 @@
|
||||
[isort]
|
||||
profile = hug
|
||||
src_paths = isort, test
|
||||
|
||||
[yapf]
|
||||
based_on_style = pep8
|
||||
spaces_before_comment = 2
|
||||
split_before_logical_operator = true
|
||||
|
||||
[pylint]
|
||||
disable =
|
||||
too-many-arguments,
|
||||
too-few-public-methods,
|
||||
fixme,
|
||||
|
||||
|
||||
[pycodestyle]
|
||||
max-line-length = 79
|
||||
|
||||
[isort]
|
||||
profile = hug
|
||||
src_paths = isort, test
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = True
|
||||
force_grid_wrap = 3
|
||||
use_parentheses = True
|
||||
line_length = 79
|
||||
|
12
setup.py
12
setup.py
@@ -18,12 +18,12 @@ PLUGIN_NAME = "models"
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||
|
||||
with open("README.md") as fh:
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
INSTALL_REQUIRES = [
|
||||
"prototorch>=0.7.0",
|
||||
"pytorch_lightning>=1.3.5",
|
||||
"prototorch>=0.7.3",
|
||||
"pytorch_lightning>=1.6.0",
|
||||
"torchmetrics",
|
||||
]
|
||||
CLI = [
|
||||
@@ -54,7 +54,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
|
||||
|
||||
setup(
|
||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||
version="0.3.0",
|
||||
version="0.5.0",
|
||||
description="Pre-packaged prototype-based "
|
||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||
long_description=long_description,
|
||||
@@ -64,7 +64,7 @@ setup(
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license="MIT",
|
||||
python_requires=">=3.6",
|
||||
python_requires=">=3.7",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"dev": DEV,
|
||||
@@ -80,10 +80,10 @@ setup(
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
|
@@ -1,14 +0,0 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDummy(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_dummy(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
195
tests/test_models.py
Normal file
195
tests/test_models.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def test_glvq_model_build():
|
||||
model = pt.models.GLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq1_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq21_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gmlvq_model_build():
|
||||
model = pt.models.GMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_grlvq_model_build():
|
||||
model = pt.models.GRLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gtlvq_model_build():
|
||||
model = pt.models.GTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lgmlvq_model_build():
|
||||
model = pt.models.LGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_image_glvq_model_build():
|
||||
model = pt.models.ImageGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gmlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gtlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_glvq_model_build():
|
||||
model = pt.models.SiameseGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gmlvq_model_build():
|
||||
model = pt.models.SiameseGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gtlvq_model_build():
|
||||
model = pt.models.SiameseGTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_knn_model_build():
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
model = pt.models.KNN(dict(k=3), data=train_ds)
|
||||
|
||||
|
||||
def test_lvq1_model_build():
|
||||
model = pt.models.LVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lvq21_model_build():
|
||||
model = pt.models.LVQ21(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_median_lvq_model_build():
|
||||
model = pt.models.MedianLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_celvq_model_build():
|
||||
model = pt.models.CELVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_rslvq_model_build():
|
||||
model = pt.models.RSLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_slvq_model_build():
|
||||
model = pt.models.SLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_growing_neural_gas_model_build():
|
||||
model = pt.models.GrowingNeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_kohonen_som_model_build():
|
||||
model = pt.models.KohonenSOM(
|
||||
{"shape": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_neural_gas_model_build():
|
||||
model = pt.models.NeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
)
|
Reference in New Issue
Block a user