Compare commits

..

3 Commits

Author SHA1 Message Date
Jensun Ravichandran
aeb6417c28
refactor: minor changes in probabilistic.py 2021-08-06 13:49:29 +02:00
Jensun Ravichandran
cb7fb91c95
feat: add binnam_xor.py 2021-07-15 18:19:28 +02:00
Jensun Ravichandran
823b05e390
feat: add neural additive model for binary classification 2021-07-14 20:07:34 +02:00
54 changed files with 1209 additions and 2316 deletions

View File

@ -1,13 +1,13 @@
[bumpversion]
current_version = 0.7.1
current_version = 0.2.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:pyproject.toml]
[bumpversion:file:setup.py]
[bumpversion:file:./src/prototorch/models/__init__.py]
[bumpversion:file:./prototorch/models/__init__.py]
[bumpversion:file:./docs/source/conf.py]

15
.codacy.yml Normal file
View File

@ -0,0 +1,15 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

2
.codecov.yml Normal file
View File

@ -0,0 +1,2 @@
comment:
require_changes: yes

View File

@ -1,25 +0,0 @@
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: examples
on:
push:
paths:
- "examples/**.py"
jobs:
cpu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/

View File

@ -1,75 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: tests
on:
push:
pull_request:
branches: [master]
jobs:
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v3.0.0
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
- os: windows-latest
python-version: "3.10"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install build
- name: Build package
run: python -m build . -C verbose
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -2,8 +2,8 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@ -12,42 +12,41 @@ repos:
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v2.1.1
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.12.0
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.902
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0
hooks:
- id: yapf
additional_dependencies: ["toml"]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
hooks:
- id: pyupgrade
- repo: https://github.com/si-cim/gitlint
- repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial
hooks:
- id: gitlint

25
.travis.yml Normal file
View File

@ -0,0 +1,25 @@
dist: bionic
sudo: false
language: python
python: 3.9
cache:
directories:
- "$HOME/.cache/pip"
- "./tests/artifacts"
- "$HOME/datasets"
install:
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
- pip install .[all] --progress-bar off
script:
- coverage run -m pytest
- ./tests/test_examples.sh examples/
after_success:
- bash <(curl -s https://codecov.io/bash)
deploy:
provider: pypi
username: __token__
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true

View File

@ -1,5 +1,6 @@
# ProtoTorch Models
[![Build Status](https://api.travis-ci.com/si-cim/prototorch_models.svg?branch=main)](https://travis-ci.com/github/si-cim/prototorch_models)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)

View File

@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.7.1"
release = "0.2.0"
# -- General configuration ---------------------------------------------------

View File

@ -2,252 +2,223 @@
"cells": [
{
"cell_type": "markdown",
"id": "7ac5eff0",
"metadata": {},
"source": [
"# A short tutorial for the `prototorch.models` plugin"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "beb83780",
"metadata": {},
"source": [
"## Introduction"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "43b74278",
"metadata": {},
"source": [
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
"\n",
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
"\n",
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "4e5d1fad",
"metadata": {},
"source": [
"## Basics"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "1244b66b",
"metadata": {},
"source": [
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcb88e8a",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "1adbe2f8",
"metadata": {},
"source": [
"### Building Models"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "96663ab1",
"metadata": {},
"source": [
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "819ba756",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" hparams=dict(distribution=[1, 1, 1]),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b37e97c",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "d2c86903",
"metadata": {},
"source": [
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
"\n",
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "45806052",
"metadata": {},
"source": [
"### Data"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "9d62c4c6",
"metadata": {},
"source": [
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "504df02c",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b8e7756",
"metadata": {},
"outputs": [],
"source": [
"type(train_ds)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "bce43afa",
"metadata": {},
"outputs": [],
"source": [
"train_ds.data.shape, train_ds.targets.shape"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "26a83328",
"metadata": {},
"source": [
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "67b80fbe",
"metadata": {},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1185f31",
"metadata": {},
"outputs": [],
"source": [
"type(train_loader)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5a8963",
"metadata": {},
"outputs": [],
"source": [
"x_batch, y_batch = next(iter(train_loader))\n",
"print(f\"{x_batch=}, {y_batch=}\")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "dd492ee2",
"metadata": {},
"source": [
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "5176b055",
"metadata": {},
"source": [
"### Training"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "46a7a506",
"metadata": {},
"source": [
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "279e75b7",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "e496b492",
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model, train_loader)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "497fbff6",
"metadata": {},
"source": [
"### From data to a trained model - a very minimal example"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab069c5d",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
@ -259,239 +230,49 @@
"\n",
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
"trainer.fit(model, train_loader)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "30c71a93",
"metadata": {},
"source": [
"### Saving/Loading trained models"
]
},
{
"cell_type": "markdown",
"id": "f74ed2c1",
"metadata": {},
"source": [
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3156658d",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
"trainer.save_checkpoint(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c34055",
"metadata": {},
"outputs": [],
"source": [
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
]
},
{
"cell_type": "markdown",
"id": "bbbb08e9",
"metadata": {},
"source": [
"### Visualizing decision boundaries in 2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ca52dc",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
]
},
{
"cell_type": "markdown",
"id": "8373531f",
"metadata": {},
"source": [
"### Saving/Loading trained weights"
]
},
{
"cell_type": "markdown",
"id": "937bc458",
"metadata": {},
"source": [
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2035af",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
"torch.save(model.state_dict(), ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1206021a",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2)),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2a4beb",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d2fc2",
"metadata": {},
"outputs": [],
"source": [
"torch.load(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec817e6b",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a208eab7",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "f8de748f",
"metadata": {},
"source": [
"## Advanced"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "53a64063",
"metadata": {},
"source": [
"### Warm-start a model with prototypes learned from another model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3177c277",
"metadata": {},
"outputs": [],
"source": [
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
"model = pt.models.SiameseGMLVQ(\n",
" dict(input_dim=2,\n",
" latent_dim=2,\n",
" distribution=(3, 2),\n",
" proto_lr=0.0001,\n",
" bb_lr=0.0001),\n",
" optimizer=torch.optim.Adam,\n",
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8baee9a2",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc203088",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "1f6a33a5",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "946ce341",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST\n",
"from torchvision.utils import make_grid"
]
"from torchvision.datasets import MNIST"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "510d9bd4",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea7c1228",
"metadata": {},
"outputs": [],
"source": [
"train_ds = MNIST(\n",
" \"~/datasets\",\n",
@ -503,87 +284,59 @@
" transforms.ToTensor(),\n",
" ]),\n",
")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9eaf5c",
"metadata": {},
"outputs": [],
"source": [
"s = int(0.05 * len(train_ds))\n",
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c32c9f2",
"metadata": {},
"outputs": [],
"source": [
"init_ds"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "68a9a8b9",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.ImageGLVQ(\n",
" dict(distribution=(10, 1)),\n",
" dict(distribution=(10, 5)),\n",
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f23df86",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(model.get_prototype_grid(num_columns=5))"
]
"plt.imshow(model.get_prototype_grid(num_columns=10))"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "1c23c7b2",
"metadata": {},
"source": [
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30780927",
"metadata": {},
"outputs": [],
"source": [
"protos, plabels = pt.components.LabeledComponents(\n",
" distribution=(10, 5),\n",
" components_initializer=pt.initializers.SMCI(init_ds),\n",
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
")()\n",
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
]
},
{
"cell_type": "markdown",
"id": "4fa69f92",
"metadata": {},
"source": [
"## FAQs"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "fa20f9ac",
"metadata": {},
"source": [
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
"\n",
@ -598,12 +351,11 @@
"```python\n",
">>> model.prototype_labels\n",
"```"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"id": "ba8215bf",
"metadata": {},
"source": [
"### How do I make inferences/predictions/recall with my trained model?\n",
"\n",
@ -618,12 +370,13 @@
"```python\n",
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
"```"
]
],
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -637,7 +390,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.4"
}
},
"nbformat": 4,

View File

@ -0,0 +1,81 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Tecator("~/datasets")
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(lr=0.1)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 100)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
es,
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(10, 10)
for i, ax in enumerate(axes.flat):
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
plt.show()

86
examples/binnam_xor.py Normal file
View File

@ -0,0 +1,86 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.XOR()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
# Hyperparameters
hparams = dict(lr=0.001)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self, hidden_size=10):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, 1),
torch.nn.ReLU(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.Vis2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=50,
mode="min",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes.flat):
x = torch.linspace(0, 1, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}")
plt.show()

View File

@ -1,32 +1,25 @@
"""CBC example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
import torch
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=32)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
# Hyperparameters
hparams = dict(
@ -37,32 +30,23 @@ if __name__ == "__main__":
)
# Initialize the model
model = CBC(
model = pt.models.CBC(
hparams,
components_initializer=pt.initializers.SSCI(train_ds, noise=0.1),
reasonings_initializer=pt.initializers.
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
reasonings_iniitializer=pt.initializers.
PurePositiveReasoningsInitializer(),
)
# Callbacks
vis = VisCBC2D(
data=train_ds,
vis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True,
)
axis_off=True)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
)
# Training loop

8
examples/cli/README.md Normal file
View File

@ -0,0 +1,8 @@
# Examples using Lightning CLI
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```

19
examples/cli/gmlvq.py Normal file
View File

@ -0,0 +1,19 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import torch
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

11
examples/cli/gmlvq.yaml Normal file
View File

@ -0,0 +1,11 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@ -1,50 +1,30 @@
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
CELVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
num_classes = 4
num_features = 2
num_clusters = 1
train_ds = pt.datasets.Random(
num_samples=500,
train_ds = pt.datasets.Random(num_samples=500,
num_classes=num_classes,
num_features=num_features,
num_clusters=num_clusters,
separation=3.0,
seed=42,
)
seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
# Hyperparameters
prototypes_per_class = num_clusters * 5
@ -54,7 +34,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = CELVQ(
model = pt.models.CELVQ(
hparams,
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
)
@ -63,18 +43,18 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
print(model)
# Callbacks
vis = VisGLVQ2D(train_ds)
pruning = PruneLoserPrototypes(
vis = pt.models.VisGLVQ2D(train_ds)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1%
idle_epochs=20, # pruning too early may cause problems
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
frequency=1, # prune every epoch
verbose=True,
)
es = EarlyStopping(
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
@ -84,18 +64,17 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
progress_bar_refresh_rate=0,
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop

View File

@ -1,35 +1,23 @@
"""GLVQ example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
@ -41,7 +29,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = GLVQ(
model = pt.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
@ -53,30 +41,15 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./glvq_iris.ckpt")
# Load saved model
new_model = GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
logging.info(new_model)

View File

@ -1,39 +1,22 @@
"""GMLVQ example using the spiral dataset."""
"""GLVQ example using the spiral dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
GMLVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
# Hyperparameters
num_classes = 2
@ -49,19 +32,19 @@ if __name__ == "__main__":
)
# Initialize the model
model = GMLVQ(
model = pt.models.GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
)
# Callbacks
vis = VisGLVQ2D(
vis = pt.models.VisGLVQ2D(
train_ds,
show_last_only=False,
block=False,
)
pruning = PruneLoserPrototypes(
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=5,
@ -70,7 +53,7 @@ if __name__ == "__main__":
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
verbose=True,
)
es = EarlyStopping(
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=1.0,
patience=5,
@ -79,18 +62,14 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
pruning,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
terminate_on_nan=True,
)
# Training loop

View File

@ -1,78 +0,0 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=4,
latent_dim=4,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 4)
# Callbacks
vis = VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@ -1,33 +1,17 @@
"""GMLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
ImageGMLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
@ -49,8 +33,12 @@ if __name__ == "__main__":
)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=4, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=4, batch_size=256)
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=256)
test_loader = torch.utils.data.DataLoader(test_ds,
num_workers=0,
batch_size=256)
# Hyperparameters
num_classes = 10
@ -64,14 +52,14 @@ if __name__ == "__main__":
)
# Initialize the model
model = ImageGMLVQ(
model = pt.models.ImageGMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Callbacks
vis = VisImgComp(
vis = pt.models.VisImgComp(
data=train_ds,
num_columns=10,
show=False,
@ -81,14 +69,14 @@ if __name__ == "__main__":
embedding_data=200,
flatten_data=False,
)
pruning = PruneLoserPrototypes(
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = EarlyStopping(
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
@ -97,18 +85,16 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
# es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
terminate_on_nan=True,
weights_summary=None,
# accelerator="ddp",
)
# Training loop

View File

@ -1,33 +1,23 @@
"""Growing Neural Gas example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
pl.utilities.seed.seed_everything(seed=42)
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=64)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
@ -37,7 +27,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = GrowingNeuralGas(
model = pt.models.GrowingNeuralGas(
hparams,
prototypes_initializer=pt.initializers.ZCI(2),
)
@ -46,22 +36,17 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Model summary
logging.info(model)
print(model)
# Callbacks
vis = VisNG2D(data=train_loader)
vis = pt.models.VisNG2D(data=train_loader)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
callbacks=[vis],
weights_summary="full",
)
# Training loop

View File

@ -1,77 +0,0 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris([0, 1])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=2,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GRLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=5,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@ -1,119 +0,0 @@
"""GTLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
ImageGTLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
test_ds = MNIST(
"~/datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=0, batch_size=256)
# Hyperparameters
num_classes = 10
prototypes_per_class = 1
hparams = dict(
input_dim=28 * 28,
latent_dim=28,
distribution=(num_classes, prototypes_per_class),
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = ImageGTLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
#Use one batch of data for subspace initiator.
omega_initializer=pt.initializers.PCALinearTransformInitializer(
next(iter(train_loader))[0].reshape(256, 28 * 28)))
# Callbacks
vis = VisImgComp(
data=train_ds,
num_columns=10,
show=False,
tensorboard=True,
random_data=100,
add_embedding=True,
embedding_data=200,
flatten_data=False,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
# using GPUs here is strongly recommended!
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,79 +0,0 @@
"""Localized-GTLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=256,
shuffle=True,
)
# Hyperparameters
# Latent_dim should be lower than input dim.
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
# Initialize the model
model = GTLVQ(hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,75 +1,51 @@
"""k-NN example using the Iris dataset from scikit-learn."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import KNN, VisGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
X, y = load_iris(return_X_y=True)
X = X[:, 0:3:2]
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
test_size=0.5,
random_state=42,
)
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=16)
test_loader = DataLoader(test_ds, batch_size=16)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(k=5)
# Initialize the model
model = KNN(hparams, data=train_ds)
model = pt.models.KNN(hparams, data=train_ds)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
print(model)
# Callbacks
vis = VisGLVQ2D(
data=(X_train, y_train),
vis = pt.models.VisGLVQ2D(
data=(x_train, y_train),
resolution=200,
block=True,
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=1,
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
callbacks=[vis],
weights_summary="full",
)
# Training loop
@ -77,8 +53,5 @@ if __name__ == "__main__":
trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(X_train))
logging.info(y_pred)
# Test
trainer.test(model, dataloaders=test_loader)
y_pred = model.predict(torch.tensor(x_train))
print(y_pred)

View File

@ -1,25 +1,15 @@
"""Kohonen Self Organizing Map."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Vis2DColorSOM(pl.Callback):
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
super().__init__()
self.title = title
@ -27,7 +17,7 @@ class Vis2DColorSOM(pl.Callback):
self.data = data
self.pause_time = pause_time
def on_train_epoch_end(self, trainer, pl_module: KohonenSOM):
def on_epoch_end(self, trainer, pl_module):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
@ -40,14 +30,12 @@ class Vis2DColorSOM(pl.Callback):
d = pl_module.compute_distances(self.data)
wp = pl_module.predict_from_distances(d)
for i, iloc in enumerate(wp):
plt.text(
iloc[1],
plt.text(iloc[1],
iloc[0],
color_names[i],
cnames[i],
ha="center",
va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0),
)
bbox=dict(facecolor="white", alpha=0.5, lw=0))
if trainer.current_epoch != trainer.max_epochs - 1:
plt.pause(self.pause_time)
@ -58,12 +46,11 @@ class Vis2DColorSOM(pl.Callback):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
pl.utilities.seed.seed_everything(seed=42)
# Prepare the data
hex_colors = [
@ -71,15 +58,15 @@ if __name__ == "__main__":
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
]
color_names = [
cnames = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey", "olive", "purple", "orange"
]
colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0
train_ds = TensorDataset(data)
train_loader = DataLoader(train_ds, batch_size=8)
train_ds = torch.utils.data.TensorDataset(data)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
# Hyperparameters
hparams = dict(
@ -90,7 +77,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = KohonenSOM(
model = pt.models.KohonenSOM(
hparams,
prototypes_initializer=pt.initializers.RNCI(3),
)
@ -99,22 +86,17 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 3)
# Model summary
logging.info(model)
print(model)
# Callbacks
vis = Vis2DColorSOM(data=data)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=500,
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
callbacks=[vis],
weights_summary="full",
)
# Training loop

View File

@ -1,36 +1,27 @@
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=2)
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
hparams = dict(
@ -40,7 +31,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = LGMLVQ(
model = pt.models.LGMLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
@ -49,11 +40,11 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
print(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
@ -63,17 +54,14 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop

View File

@ -1,26 +1,13 @@
"""LVQMLN example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
LVQMLN,
PruneLoserPrototypes,
VisSiameseGLVQ2D,
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
@ -39,18 +26,17 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=42)
pl.utilities.seed.seed_everything(seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
@ -63,7 +49,7 @@ if __name__ == "__main__":
backbone = Backbone()
# Initialize the model
model = LVQMLN(
model = pt.models.LVQMLN(
hparams,
prototypes_initializer=pt.initializers.SSCI(
train_ds,
@ -72,15 +58,18 @@ if __name__ == "__main__":
backbone=backbone,
)
# Model summary
print(model)
# Callbacks
vis = VisSiameseGLVQ2D(
vis = pt.models.VisSiameseGLVQ2D(
data=train_ds,
map_protos=False,
border=0.1,
resolution=500,
axis_off=True,
)
pruning = PruneLoserPrototypes(
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
@ -89,17 +78,12 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
],
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
)
# Training loop

View File

@ -1,40 +1,28 @@
"""Median-LVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
)
# Initialize the model
model = MedianLVQ(
model = pt.models.MedianLVQ(
hparams=dict(distribution=(3, 2), lr=0.01),
prototypes_initializer=pt.initializers.SSCI(train_ds),
)
@ -43,8 +31,8 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.01,
patience=5,
@ -54,17 +42,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis, es],
weights_summary="full",
)
# Training loop

View File

@ -1,35 +1,23 @@
"""Neural Gas example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare and pre-process the dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, 0:3:2]
x_train = x_train[:, [0, 2]]
scaler = StandardScaler()
scaler.fit(x_train)
x_train = scaler.transform(x_train)
@ -37,7 +25,7 @@ if __name__ == "__main__":
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
@ -47,7 +35,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = NeuralGas(
model = pt.models.NeuralGas(
hparams,
prototypes_initializer=pt.core.ZCI(2),
lr_scheduler=ExponentialLR,
@ -57,20 +45,17 @@ if __name__ == "__main__":
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Model summary
print(model)
# Callbacks
vis = VisNG2D(data=train_ds)
vis = pt.models.VisNG2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
)
# Training loop

View File

@ -1,34 +1,25 @@
"""RSLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
pl.utilities.seed.seed_everything(seed=42)
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
@ -42,7 +33,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = RSLVQ(
model = pt.models.RSLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
@ -51,20 +42,19 @@ if __name__ == "__main__":
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
detect_anomaly=True,
max_epochs=100,
log_every_n_steps=1,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop

View File

@ -1,22 +1,13 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
@ -35,50 +26,46 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=2)
pl.utilities.seed.seed_everything(seed=2)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
lr=0.01,
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the backbone
backbone = Backbone()
# Initialize the model
model = SiameseGLVQ(
model = pt.models.SiameseGLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Model summary
print(model)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
)
# Training loop

View File

@ -1,87 +0,0 @@
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=2)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
lr=0.01,
input_dim=2,
latent_dim=1,
)
# Initialize the backbone
backbone = Backbone(latent_size=hparams["input_dim"])
# Initialize the model
model = SiameseGTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,42 +1,24 @@
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
GLVQ,
KNN,
GrowingNeuralGas,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Initialize the gng
gng = GrowingNeuralGas(
gng = pt.models.GrowingNeuralGas(
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
prototypes_initializer=pt.initializers.ZCI(2),
lr_scheduler=ExponentialLR,
@ -44,7 +26,7 @@ if __name__ == "__main__":
)
# Callbacks
es = EarlyStopping(
es = pl.callbacks.EarlyStopping(
monitor="loss",
min_delta=0.001,
patience=20,
@ -55,14 +37,9 @@ if __name__ == "__main__":
# Setup trainer for GNG
trainer = pl.Trainer(
accelerator="cpu",
max_epochs=50 if args.fast_dev_run else
1000, # 10 epochs fast dev run reproducible DIV error.
callbacks=[
es,
],
log_every_n_steps=1,
detect_anomaly=True,
max_epochs=100,
callbacks=[es],
weights_summary=None,
)
# Training loop
@ -75,12 +52,12 @@ if __name__ == "__main__":
)
# Warm-start prototypes
knn = KNN(dict(k=1), data=train_ds)
knn = pt.models.KNN(dict(k=1), data=train_ds)
prototypes = gng.prototypes
plabels = knn.predict(prototypes)
# Initialize the model
model = GLVQ(
model = pt.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.LCI(prototypes),
@ -93,15 +70,15 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
pruning = PruneLoserPrototypes(
vis = pt.models.VisGLVQ2D(data=train_ds)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.02,
idle_epochs=2,
prune_quota_per_epoch=5,
frequency=1,
verbose=True,
)
es = EarlyStopping(
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=10,
@ -111,18 +88,15 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop

View File

@ -1,5 +1,7 @@
"""`models` plugin for the `prototorch` package."""
from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
@ -8,32 +10,18 @@ from .glvq import (
GLVQ21,
GMLVQ,
GRLVQ,
GTLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
ImageGTLVQ,
SiameseGLVQ,
SiameseGMLVQ,
SiameseGTLVQ,
)
from .knn import KNN
from .lvq import (
LVQ1,
LVQ21,
MedianLVQ,
)
from .probabilistic import (
CELVQ,
RSLVQ,
SLVQ,
)
from .unsupervised import (
GrowingNeuralGas,
KohonenSOM,
NeuralGas,
)
from .lvq import LVQ1, LVQ21, MedianLVQ
from .nam import BinaryNAM
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *
__version__ = "0.7.1"
__version__ = "0.2.0"

View File

@ -1,29 +1,21 @@
"""Abstract classes to be inherited by prototorch models."""
import logging
from typing import Final, final
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from prototorch.core.competitions import WTAC
from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.core.pooling import stratified_min_pooling
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
def __init__(self, hparams, **kwargs):
super().__init__()
@ -39,7 +31,7 @@ class ProtoTorchBolt(pl.LightningModule):
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
@ -51,11 +43,9 @@ class ProtoTorchBolt(pl.LightningModule):
else:
return optimizer
@final
def reconfigure_optimizers(self):
if self.trainer:
self.trainer.strategy.setup_optimizers(self.trainer)
else:
logging.warning("No trainer to reconfigure optimizers!")
self.trainer.accelerator.setup_optimizers(self.trainer)
def __repr__(self):
surep = super().__repr__()
@ -65,13 +55,11 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt):
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
@ -88,18 +76,14 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
proto_layer: Components
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@ -107,12 +91,12 @@ class UnsupervisedPrototypeModel(PrototypeModel):
prototypes_initializer = kwargs.get("prototypes_initializer", None)
if prototypes_initializer is not None:
self.proto_layer = Components(
self.hparams["num_prototypes"],
self.hparams.num_prototypes,
initializer=prototypes_initializer,
)
def compute_distances(self, x):
protos = self.proto_layer().type_as(x)
protos = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
@ -122,34 +106,19 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel):
proto_layer: LabeledComponents
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents(
distribution=distribution,
distribution=self.hparams.distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
proto_shape = self.proto_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
self.competition_layer = WTAC()
@property
@ -167,14 +136,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x):
distances = self.compute_distances(x)
_, plabels = self.proto_layer()
plabels = self.proto_layer.labels
winning = stratified_min_pooling(distances, plabels)
y_pred = F.softmin(winning, dim=1)
y_pred = torch.nn.functional.softmin(winning)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
_, plabels = self.proto_layer()
plabels = self.proto_layer.labels
y_pred = self.competition_layer(distances, plabels)
return y_pred
@ -186,57 +155,36 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
self.log(
tag,
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy)
logger=True)
class ProtoTorchMixin:
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
self.automatic_optimization: Final = False
def training_step(self, train_batch, batch_idx):
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx):
@final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)

View File

@ -1,30 +1,24 @@
"""Lightning Callbacks."""
import logging
from typing import TYPE_CHECKING
import pytorch_lightning as pl
import torch
from prototorch.core.initializers import LiteralCompInitializer
from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology
if TYPE_CHECKING:
from prototorch.models import GLVQ, GrowingNeuralGas
class PruneLoserPrototypes(pl.Callback):
def __init__(
self,
def __init__(self,
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=-1,
frequency=1,
replace=False,
prototypes_initializer=None,
verbose=False,
):
verbose=False):
self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning
self.prune_quota_per_epoch = prune_quota_per_epoch
@ -33,7 +27,7 @@ class PruneLoserPrototypes(pl.Callback):
self.verbose = verbose
self.prototypes_initializer = prototypes_initializer
def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if (trainer.current_epoch + 1) % self.frequency:
@ -48,44 +42,41 @@ class PruneLoserPrototypes(pl.Callback):
prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0:
logging.debug(f"\nPrototype win ratios: {ratios}")
logging.debug(f"Pruning prototypes at: {to_prune}")
logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
if self.verbose:
print(f"\nPrototype win ratios: {ratios}")
print(f"Pruning prototypes at: {to_prune}")
print(f"Corresponding labels are: {prune_labels.tolist()}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
if self.replace:
labels, counts = torch.unique(prune_labels,
sorted=True,
return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist()))
logging.info(f"Re-adding pruned prototypes...")
logging.debug(f"distribution={distribution}")
if self.verbose:
print(f"Re-adding pruned prototypes...")
print(f"{distribution=}")
pl_module.add_prototypes(
distribution=distribution,
components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes
logging.info(f"`num_prototypes` changed from {cur_num_protos} "
if self.verbose:
print(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.")
return True
class PrototypeConvergence(pl.Callback):
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
self.min_delta = min_delta
self.idle_epochs = idle_epochs # epochs to wait
self.verbose = verbose
def on_train_epoch_end(self, trainer, pl_module):
def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
logging.info("Stopping...")
if self.verbose:
print("Stopping...")
# TODO
return True
@ -98,21 +89,16 @@ class GNGCallback(pl.Callback):
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_train_epoch_end(
self,
trainer: pl.Trainer,
pl_module: "GrowingNeuralGas",
):
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components = pl_module.proto_layer.components
components: Components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
@ -132,9 +118,8 @@ class GNGCallback(pl.Callback):
# Add component
pl_module.proto_layer.add_components(
1,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
)
None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
# Adjust Topology
topology.add_prototype()
@ -149,4 +134,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.strategy.setup_optimizers(trainer)
trainer.accelerator_backend.setup_optimizers(trainer)

View File

@ -1,21 +1,20 @@
import torch
import torchmetrics
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import RandomReasoningsInitializer
from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import CBCC
from ..core.components import ReasoningComponents
from ..core.initializers import RandomReasoningsInitializer
from ..core.losses import MarginLoss
from ..core.similarities import euclidean_similarity
from ..nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs)
super().__init__(hparams, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None)
@ -44,7 +43,7 @@ class CBC(SiameseGLVQ):
probs = self.competition_layer(detections, reasonings)
return probs
def shared_step(self, batch, batch_idx):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
y_pred = self(x)
num_classes = self.num_classes
@ -52,23 +51,17 @@ class CBC(SiameseGLVQ):
loss = self.loss(y_pred, y_true).mean()
return y_pred, loss
def training_step(self, batch, batch_idx):
y_pred, train_loss = self.shared_step(batch, batch_idx)
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(
preds.int(),
batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy = torchmetrics.functional.accuracy(preds.int(),
batch[1].int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
logger=True)
return train_loss
def predict(self, x):

123
prototorch/models/data.py Normal file
View File

@ -0,0 +1,123 @@
"""Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import prototorch as pt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
# Download mnist dataset as side-effect, only called on the first cpu
def prepare_data(self):
MNIST("~/datasets", train=True, download=True)
MNIST("~/datasets", train=False, download=True)
# called for every GPU/machine (assigning state is OK)
def setup(self, stage=None):
# Transforms
transform = transforms.Compose([
transforms.ToTensor(),
])
# Split dataset
if stage in (None, "fit"):
mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(
mnist_train,
[55000, 5000],
)
if stage == (None, "test"):
self.mnist_test = MNIST(
"~/datasets",
train=False,
transform=transform,
)
# Dataloaders
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
return mnist_test
# def train_on_mnist(batch_size=256) -> type:
# class DataClass(pl.LightningModule):
# datamodule = MNISTDataModule(batch_size=batch_size)
# def __init__(self, *args, **kwargs):
# prototype_initializer = kwargs.pop(
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)

View File

@ -5,7 +5,8 @@ Modules not yet available in prototorch go here temporarily.
"""
import torch
from prototorch.core.similarities import gaussian
from ..core.similarities import gaussian
def rank_scaled_gaussian(distances, lambd):
@ -14,46 +15,7 @@ def rank_scaled_gaussian(distances, lambd):
return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module):
def __init__(self, variance):
super().__init__()
self.variance = variance
@ -63,7 +25,6 @@ class GaussianPrior(torch.nn.Module):
class RankScaledGaussianPrior(torch.nn.Module):
def __init__(self, lambd):
super().__init__()
self.lambd = lambd
@ -73,7 +34,6 @@ class RankScaledGaussianPrior(torch.nn.Module):
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit

View File

@ -1,29 +1,19 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.core.initializers import EyeLinearTransformInitializer
from prototorch.core.losses import (
GLVQLoss,
lvq1_loss,
lvq21_loss,
)
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@ -34,21 +24,17 @@ class GLVQ(SupervisedPrototypeModel):
# Loss
self.loss = GLVQLoss(
margin=self.hparams["margin"],
transfer_fn=self.hparams["transfer_fn"],
beta=self.hparams["transfer_beta"],
margin=self.hparams.margin,
transfer_fn=self.hparams.transfer_fn,
beta=self.hparams.transfer_beta,
)
# def on_save_checkpoint(self, checkpoint):
# if "prototype_win_ratios" in checkpoint["state_dict"]:
# del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
def on_train_epoch_start(self):
def on_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances):
@ -66,15 +52,15 @@ class GLVQ(SupervisedPrototypeModel):
prototype_wr,
])
def shared_step(self, batch, batch_idx):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
_, plabels = self.proto_layer()
plabels = self.proto_layer.labels
loss = self.loss(out, y, plabels)
return out, loss
def training_step(self, batch, batch_idx):
out, train_loss = self.shared_step(batch, batch_idx)
def training_step(self, batch, batch_idx, optimizer_idx=None):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc")
@ -99,6 +85,10 @@ class GLVQ(SupervisedPrototypeModel):
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
# TODO
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
@ -108,7 +98,6 @@ class SiameseGLVQ(GLVQ):
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
hparams,
backbone=torch.nn.Identity(),
@ -119,17 +108,32 @@ class SiameseGLVQ(GLVQ):
self.backbone = backbone
self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
# Only add a backbone optimizer if backbone has trainable parameters
if (bb_params := list(self.backbone.parameters())):
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(bb_grad)
self.backbone.requires_grad_(True)
distances = self.distance_layer(latent_x, latent_protos)
return distances
@ -159,7 +163,6 @@ class LVQMLN(SiameseGLVQ):
rather in the embedding space.
"""
def compute_distances(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
@ -175,22 +178,17 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
"""
_relevances: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Additional parameters
relevances = torch.ones(self.hparams["input_dim"], device=self.device)
relevances = torch.ones(self.hparams.input_dim, device=self.device)
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
self.backbone = LambdaLayer(self._apply_relevances,
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property
def relevance_profile(self):
return self._relevances.detach().cpu()
@ -205,16 +203,15 @@ class SiameseGMLVQ(SiameseGLVQ):
Implemented as a Siamese network with a linear transformation backbone.
"""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Override the backbone
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
EyeTransformInitializer())
self.backbone = LinearTransform(
self.hparams["input_dim"],
self.hparams["latent_dim"],
self.hparams.input_dim,
self.hparams.output_dim,
initializer=omega_initializer,
)
@ -224,7 +221,7 @@ class SiameseGMLVQ(SiameseGLVQ):
@property
def lambda_matrix(self):
omega = self.backbone.weights # (input_dim, latent_dim)
omega = self.backbone.weight # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
@ -236,31 +233,23 @@ class GMLVQ(GLVQ):
function. This makes it easier to implement a localized variant.
"""
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"])
EyeTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)
@ -272,7 +261,6 @@ class GMLVQ(GLVQ):
class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", lomega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@ -280,59 +268,15 @@ class LGMLVQ(GMLVQ):
# Re-register `_omega` to override the one from the super class.
omega = torch.randn(
self.num_prototypes,
self.hparams["input_dim"],
self.hparams["latent_dim"],
self.hparams.input_dim,
self.hparams.latent_dim,
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Tangent Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None:
subspace = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
omega = torch.repeat_interleave(
subspace.unsqueeze(0),
self.num_prototypes,
dim=0,
)
else:
omega = torch.rand(
self.num_prototypes,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
# Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
"""Generalized Tangent Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
"""
class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq1_loss)
@ -341,7 +285,6 @@ class GLVQ1(GLVQ):
class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq21_loss)
@ -364,18 +307,3 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
after updates.
"""
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
"""GTLVQ for training on image data.
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))

View File

@ -2,22 +2,17 @@
import warnings
from prototorch.core.competitions import KNNC
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
LiteralCompInitializer,
LiteralLabelsInitializer,
)
from prototorch.utils.utils import parse_data_arg
from ..core.competitions import KNNC
from ..core.components import LabeledComponents
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from ..utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs)
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("k", 1)
@ -29,15 +24,18 @@ class KNN(SupervisedPrototypeModel):
# Layers
self.proto_layer = LabeledComponents(
distribution=len(data) * [1],
distribution=[],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx):
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step
def on_train_batch_start(self, train_batch, batch_idx):
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1

View File

@ -1,20 +1,18 @@
"""LVQ models that are optimized using non-gradient methods."""
import logging
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
from ..core.losses import _get_dp_dm
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
def training_step(self, train_batch, batch_idx):
protos, plables = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
@ -32,8 +30,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
logging.debug(f"dis={dis}")
logging.debug(f"y={y}")
print(f"{dis=}")
print(f"{y=}")
# Logging
self.log_acc(dis, y, tag="train_acc")
@ -42,9 +40,9 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer()
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
x, y = train_batch
dis = self.compute_distances(x)
@ -75,8 +73,8 @@ class MedianLVQ(NonGradientMixin, GLVQ):
# TODO Avoid computing distances over and over
"""
def __init__(self, hparams, **kwargs):
def __init__(self, hparams, verbose=True, **kwargs):
self.verbose = verbose
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
@ -100,8 +98,9 @@ class MedianLVQ(NonGradientMixin, GLVQ):
lower_bound = (gamma * f.log()).sum()
return lower_bound
def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer()
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
x, y = train_batch
dis = self.compute_distances(x)
@ -117,7 +116,8 @@ class MedianLVQ(NonGradientMixin, GLVQ):
_protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
logging.debug(f"Updating prototype {i} to data {k}...")
if self.verbose:
print(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict({"_components": _protos},
strict=False)
break

58
prototorch/models/nam.py Normal file
View File

@ -0,0 +1,58 @@
"""ProtoTorch Neural Additive Model."""
import torch
import torchmetrics
from .abstract import ProtoTorchBolt
class BinaryNAM(ProtoTorchBolt):
"""Neural Additive Model for binary classification.
Paper: https://arxiv.org/abs/2004.13912
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
"""
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
**kwargs):
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("threshold", 0.5)
self.extractors = extractors
self.linear = torch.nn.Linear(in_features=len(extractors),
out_features=1,
bias=True)
def extract(self, x):
"""Apply the local extractors batch-wise on features."""
out = torch.zeros_like(x)
for j in range(x.shape[1]):
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
return out
def forward(self, x):
x = self.extract(x)
x = self.linear(x)
return torch.sigmoid(x)
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
preds = self(x).squeeze()
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
self.log("train_loss", train_loss)
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
return train_loss
def predict(self, x):
out = self(x)
pred = torch.zeros_like(out, device=self.device)
pred[out > self.hparams.threshold] = 1
return pred

View File

@ -1,30 +1,25 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.core.losses import nllr_loss, rslvq_loss
from prototorch.core.pooling import (
stratified_min_pooling,
stratified_sum_pooling,
)
from prototorch.nn.wrappers import LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Loss
self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x) # [None, num_protos]
_, plabels = self.proto_layer()
plabels = self.proto_layer.labels
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
@ -33,28 +28,20 @@ class CELVQ(GLVQ):
class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
self.rejection_confidence = rejection_confidence
self._conditional_distribution = None
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
if isinstance(plabels, torch.LongTensor) or isinstance(
plabels, torch.cuda.LongTensor): # type: ignore
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
else:
raise ValueError("Labels must be LongTensor.")
y_pred = stratified_sum_pooling(posterior, plabels)
return y_pred
def predict(self, x):
@ -63,46 +50,27 @@ class ProbabilisticLVQ(GLVQ):
prediction[confidence < self.rejection_confidence] = -1
return prediction
def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
_, plabels = self.proto_layer()
plabels = self.proto_layer.labels
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum()
return loss
def conditional_distribution(self, distances):
"""Conditional distribution of distances."""
if self._conditional_distribution is None:
raise ValueError("Conditional distribution is not set.")
return self._conditional_distribution(distances)
train_loss = batch_loss.sum()
self.log("train_loss", train_loss)
return train_loss
class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss)
class RSLVQ(ProbabilisticLVQ):
"""Robust Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss)
@ -111,19 +79,14 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
TODO: Use Backbone LVQ instead
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("lambda", 1.0)
lam = self.hparams.get("lambda", 1.0)
self.conditional_distribution = RankScaledGaussianPrior(lam)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd)
self.loss = torch.nn.KLDivLoss()
# FIXME
# def training_step(self, batch, batch_idx):
# def training_step(self, batch, batch_idx, optimizer_idx=None):
# x, y = batch
# y_pred = self(x)
# batch_loss = self.loss(y_pred, y)

View File

@ -2,10 +2,11 @@
import numpy as np
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
@ -17,8 +18,6 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
TODO Allow non-2D grids
"""
_grid: torch.Tensor
def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape")
# Ignore `num_prototypes`
@ -35,7 +34,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
# Additional parameters
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma
self._lr = self.hparams.lr
@ -54,16 +53,14 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer()
protos = self.proto_layer.components
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict(
{"_components": updated_protos},
strict=False,
)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
def on_training_epoch_end(self, training_step_outputs):
def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs)
@ -72,7 +69,6 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
class HeskesSOM(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@ -82,7 +78,6 @@ class HeskesSOM(UnsupervisedPrototypeModel):
class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@ -90,13 +85,13 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams["age_limit"],
num_prototypes=self.hparams["num_prototypes"],
agelimit=self.hparams.agelimit,
num_prototypes=self.hparams.num_prototypes,
)
def training_step(self, train_batch, batch_idx):
@ -109,10 +104,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.log("loss", loss)
return loss
# def training_epoch_end(self, training_step_outputs):
# print(f"{self.trainer.lr_schedulers}")
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
class GrowingNeuralGas(NeuralGas):
errors: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@ -121,10 +118,7 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
errors = torch.zeros(
self.hparams["num_prototypes"],
device=self.device,
)
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx):
@ -139,7 +133,7 @@ class GrowingNeuralGas(NeuralGas):
dp = d * mask
self.errors += torch.sum(dp * dp)
self.errors *= self.hparams["step_reduction"]
self.errors *= self.hparams.step_reduction
self.topology_layer(d)
self.log("loss", loss)
@ -147,8 +141,6 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self):
return [
GNGCallback(
reduction=self.hparams["insert_reduction"],
freq=self.hparams["insert_freq"],
)
GNGCallback(reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq)
]

View File

@ -1,28 +1,20 @@
"""Visualization Callbacks."""
import warnings
from typing import Sized
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback):
def __init__(self,
data=None,
data,
title="Prototype Visualization",
cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
legend_labels=None,
border=0.1,
resolution=100,
flatten_data=True,
@ -35,15 +27,9 @@ class Vis2DAbstract(pl.Callback):
block=False):
super().__init__()
if data:
if isinstance(data, Dataset):
if isinstance(data, Sized):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
else:
# TODO: Add support for non-sized datasets
raise NotImplementedError(
"Data must be a dataset with a __len__ method.")
elif isinstance(data, DataLoader):
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
@ -57,14 +43,8 @@ class Vis2DAbstract(pl.Callback):
self.x_train = x
self.y_train = y
else:
self.x_train = None
self.y_train = None
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
@ -83,12 +63,14 @@ class Vis2DAbstract(pl.Callback):
return False
return True
def setup_ax(self):
def setup_ax(self, xlabel=None, ylabel=None):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)
if xlabel:
ax.set_xlabel("Data dimension 1")
if ylabel:
ax.set_ylabel("Data dimension 2")
if self.axis_off:
ax.axis("off")
return ax
@ -131,47 +113,60 @@ class Vis2DAbstract(pl.Callback):
else:
plt.show(block=self.block)
def on_train_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
self.log_and_display(trainer, pl_module)
def on_train_end(self, trainer, pl_module):
plt.close()
def visualize(self, pl_module):
raise NotImplementedError
class Vis2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def visualize(self, pl_module):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
@ -198,42 +193,18 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
@ -245,15 +216,20 @@ class VisCBC2D(Vis2DAbstract):
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
@ -267,27 +243,10 @@ class VisNG2D(Vis2DAbstract):
"k-",
)
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
self.log_and_display(trainer, pl_module)
class VisImgComp(Vis2DAbstract):
def __init__(self,
*args,
random_data=0,
@ -303,45 +262,32 @@ class VisImgComp(Vis2DAbstract):
self.add_embedding = add_embedding
self.embedding_data = embedding_data
def on_train_start(self, _, pl_module):
if isinstance(pl_module.logger, TensorBoardLogger):
def on_train_start(self, trainer, pl_module):
tb = pl_module.logger.experiment
# Add embedding
if self.add_embedding:
if self.x_train is not None and self.y_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
# print(f"{data.shape=}")
# print(f"{self.y_train[ind].shape=}")
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
else:
raise ValueError("No data for add embedding flag")
# Random Data
if self.random_data:
if self.x_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data,
nrow=self.num_columns)
grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
else:
raise ValueError("No data for random data flag")
else:
warnings.warn(
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
@ -355,9 +301,14 @@ class VisImgComp(Vis2DAbstract):
dataformats=self.dataformats,
)
def visualize(self, pl_module):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components,
nrow=self.num_columns)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)

View File

@ -1,90 +0,0 @@
[project]
name = "prototorch-models"
version = "0.7.1"
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
authors = [
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
]
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
requires-python = ">=3.8"
readme = "README.md"
license = { text = "MIT" }
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Homepage = "https://github.com/si-cim/prototorch_models"
Downloads = "https://github.com/si-cim/prototorch_models.git"
[project.optional-dependencies]
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
examples = ["matplotlib", "scikit-learn"]
ci = ["pytest", "pre-commit"]
docs = [
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
all = [
"bumpversion",
"pre-commit",
"yapf",
"toml",
"pytest",
"matplotlib",
"scikit-learn",
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 2
split_before_logical_operator = true
[tool.pylint]
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
[tool.isort]
profile = "hug"
src_paths = ["isort", "test"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 3
use_parentheses = true
line_length = 79
[tool.mypy]
explicit_package_bases = true
namespace_packages = true

8
setup.cfg Normal file
View File

@ -0,0 +1,8 @@
[isort]
profile = hug
src_paths = isort, test
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true

93
setup.py Normal file
View File

@ -0,0 +1,93 @@
"""
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.6.0",
"pytorch_lightning>=1.3.5",
"torchmetrics",
]
CLI = [
"jsonargparse",
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
]
EXAMPLES = [
"matplotlib",
"scikit-learn",
]
TESTS = [
"codecov",
"pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.2.0",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.9",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
"examples": EXAMPLES,
"tests": TESTS,
"all": ALL,
},
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.9",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
},
packages=find_namespace_packages(include=["prototorch.*"]),
zip_safe=False,
)

14
tests/test_.py Normal file
View File

@ -0,0 +1,14 @@
"""prototorch.models test suite."""
import unittest
class TestDummy(unittest.TestCase):
def setUp(self):
pass
def test_dummy(self):
pass
def tearDown(self):
pass

View File

@ -1,27 +1,11 @@
#! /bin/bash
# Read Flags
gpu=0
while [ -n "$1" ]; do
case "$1" in
--gpu) gpu=1;;
-g) gpu=1;;
*) path=$1;;
esac
shift
done
python --version
echo "Using GPU: " $gpu
# Loop
failed=0
for example in $(find $path -maxdepth 1 -name "*.py")
for example in $(find $1 -maxdepth 1 -name "*.py")
do
echo -n "$x" $example '... '
export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
if [[ $? -ne 0 ]]; then
echo "FAILED!!"
cat run_log.txt

View File

@ -1,193 +0,0 @@
"""prototorch.models test suite."""
import prototorch.models
def test_glvq_model_build():
model = prototorch.models.GLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_glvq1_model_build():
model = prototorch.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_glvq21_model_build():
model = prototorch.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_gmlvq_model_build():
model = prototorch.models.GMLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_grlvq_model_build():
model = prototorch.models.GRLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_gtlvq_model_build():
model = prototorch.models.GTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_lgmlvq_model_build():
model = prototorch.models.LGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_image_glvq_model_build():
model = prototorch.models.ImageGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(16),
)
def test_image_gmlvq_model_build():
model = prototorch.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(16),
)
def test_image_gtlvq_model_build():
model = prototorch.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(16),
)
def test_siamese_glvq_model_build():
model = prototorch.models.SiameseGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(4),
)
def test_siamese_gmlvq_model_build():
model = prototorch.models.SiameseGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(4),
)
def test_siamese_gtlvq_model_build():
model = prototorch.models.SiameseGTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=prototorch.initializers.RNCI(4),
)
def test_knn_model_build():
train_ds = prototorch.datasets.Iris(dims=[0, 2])
model = prototorch.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build():
model = prototorch.models.LVQ1(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_lvq21_model_build():
model = prototorch.models.LVQ21(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_median_lvq_model_build():
model = prototorch.models.MedianLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_celvq_model_build():
model = prototorch.models.CELVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_rslvq_model_build():
model = prototorch.models.RSLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_slvq_model_build():
model = prototorch.models.SLVQ(
{"distribution": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_growing_neural_gas_model_build():
model = prototorch.models.GrowingNeuralGas(
{"num_prototypes": 5},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_kohonen_som_model_build():
model = prototorch.models.KohonenSOM(
{"shape": (3, 2)},
prototypes_initializer=prototorch.initializers.RNCI(2),
)
def test_neural_gas_model_build():
model = prototorch.models.NeuralGas(
{"num_prototypes": 5},
prototypes_initializer=prototorch.initializers.RNCI(2),
)