75 Commits

Author SHA1 Message Date
Alexander Engelsberger
fe729781fc build: bump version 1.0.0a2 → 1.0.0a3 2022-06-09 14:59:07 +02:00
Alexander Engelsberger
a7df7be1c8 feat: add confusion matrix callback 2022-06-09 14:55:59 +02:00
Alexander Engelsberger
696719600b build: bump version 1.0.0a1 → 1.0.0a2 2022-06-03 11:52:50 +02:00
Alexander Engelsberger
48e7c029fa fix: Fix __init__.py 2022-06-03 11:40:45 +02:00
Alexander Engelsberger
5de3a480c7 build: bump version 0.5.2 → 1.0.0a1 2022-06-03 11:07:10 +02:00
Alexander Engelsberger
626f51ce80 ci: Add possible prerelease to bumpversion 2022-06-03 11:06:44 +02:00
Alexander Engelsberger
6d7d93c8e8 chore: rename y_arch to y 2022-06-03 10:39:11 +02:00
Jensun Ravichandran
93b1d0bd46 feat(vis): add flag to save visualization frames 2022-06-02 19:55:03 +02:00
Alexander Engelsberger
b7992c01db fix: apply hotfix 2022-06-01 14:26:37 +02:00
Alexander Engelsberger
fcd944d3ff build: bump version 0.5.1 → 0.5.2 2022-06-01 14:25:44 +02:00
Alexander Engelsberger
054720dd7b fix(hotfix): Protobuf error workaround 2022-06-01 14:14:57 +02:00
Alexander Engelsberger
23d1a71b31 feat: distribute GMLVQ into mixins 2022-05-31 17:56:03 +02:00
Alexander Engelsberger
e922aae432 feat: add GMLVQ with new architecture 2022-05-19 16:13:08 +02:00
Alexander Engelsberger
3e50d0d817 chore(protoy): mixin restructuring 2022-05-18 15:43:09 +02:00
Alexander Engelsberger
dc4f31d700 chore: rename clc-lc to proto-Y-architecture 2022-05-18 14:11:46 +02:00
Alexander Engelsberger
02954044d7 chore: improve clc-lc test 2022-05-17 17:26:03 +02:00
Alexander Engelsberger
8f08ba66ea feat: copy old clc-lc implementation 2022-05-17 16:25:43 +02:00
Alexander Engelsberger
e0b92e9ac2 chore: move mixins to seperate file 2022-05-17 16:19:47 +02:00
Alexander Engelsberger
d16a0de202 build: bump version 0.5.0 → 0.5.1 2022-05-17 12:04:08 +02:00
Alexander Engelsberger
76fea3f881 chore: update all examples to pytorch 1.6 2022-05-17 12:03:43 +02:00
Alexander Engelsberger
c00513ae0d chore: minor updates and version updates 2022-05-17 12:00:52 +02:00
Alexander Engelsberger
bccef8bef0 chore: replace relative imports 2022-05-16 11:12:53 +02:00
Alexander Engelsberger
29ee326b85 ci: Update PreCommit hooks 2022-05-16 11:11:48 +02:00
Jensun Ravichandran
055568dc86 fix: glvq_iris example works again 2022-05-09 17:33:52 +02:00
Alexander Engelsberger
3a7328e290 chore: small changes 2022-04-27 10:37:12 +02:00
Alexander Engelsberger
d6629c8792 build: bump version 0.4.1 → 0.5.0 2022-04-27 10:28:06 +02:00
Alexander Engelsberger
ef65bd3789 chore: update prototorch dependency 2022-04-27 09:50:48 +02:00
Alexander Engelsberger
d096eba2c9 chore: update pytorch lightning dependency 2022-04-27 09:39:00 +02:00
Alexander Engelsberger
dd34c57e2e ci: fix github action python version 2022-04-27 09:30:07 +02:00
Alexander Engelsberger
5911f4dd90 chore: fix errors for pytorch_lightning>1.6 2022-04-27 09:25:42 +02:00
Alexander Engelsberger
dbfe315f4f ci: add python 3.10 to tests 2022-04-27 09:24:34 +02:00
Jensun Ravichandran
9c90c902dc fix: correct typo 2022-04-04 21:54:04 +02:00
Jensun Ravichandran
7d3f59e54b test: add unit tests 2022-03-30 15:12:33 +02:00
Jensun Ravichandran
9da47b1dba fix: CBC example works again 2022-03-30 15:10:06 +02:00
Alexander Engelsberger
41f0e77fc9 fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
2022-03-29 17:08:40 +02:00
Jensun Ravichandran
fab786a07e fix: rename hparam output_dimlatent_dim in SiameseGMLVQ 2022-03-29 15:24:42 +02:00
Jensun Ravichandran
40bd7ed380 docs: update tutorial notebook 2022-03-29 15:04:05 +02:00
Jensun Ravichandran
4941c2b89d feat: data argument optional in some visualizers 2022-03-29 11:26:22 +02:00
Jensun Ravichandran
ce14dec7e9 feat: add VisSpectralProtos 2022-03-10 15:24:44 +01:00
Jensun Ravichandran
b31c8cc707 feat: add xlabel and ylabel arguments to visualizers 2022-03-09 13:59:19 +01:00
Jensun Ravichandran
e21e6c7e02 docs: update tutorial notebook 2022-02-15 14:38:53 +01:00
Jensun Ravichandran
dd696ea1e0 fix: update hparams.distribution as it changes during training 2022-02-02 21:53:03 +01:00
Jensun Ravichandran
15e7232747 fix: ignore prototype_win_ratios by loading with strict=False 2022-02-02 21:52:01 +01:00
Jensun Ravichandran
197b728c63 feat: add visualize method to visualization callbacks
All visualization callbacks now contain a `visualize` method that takes an
appropriate PyTorchLightning Module and visualizes it without the need for a
Trainer. This is to encourage users to perform one-off visualizations after
training.
2022-02-02 21:45:44 +01:00
Jensun Ravichandran
98892afee0 chore: add example for saving/loading models from checkpoints 2022-02-02 19:02:26 +01:00
Alexander Engelsberger
d5855dbe97 fix: GLVQ can now be restored from checkpoint 2022-02-02 16:17:11 +01:00
Alexander Engelsberger
75a39f5b03 build: bump version 0.4.0 → 0.4.1 2022-01-11 18:29:55 +01:00
Alexander Engelsberger
1a0e697b27 Merge branch 'dev' into main 2022-01-11 18:29:32 +01:00
Alexander Engelsberger
1a17193b35 ci: add github actions (#16)
* chore: update pre-commit versions

* ci: remove old configurations

* ci: copy workflow from prototorch

* ci: run precommit for all files

* ci: add examples CPU test

* ci(test): failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test II

* ci: fix test command

* ci: cleanup example test

* ci: remove travis badge
2022-01-11 18:28:50 +01:00
Alexander Engelsberger
aaa3c51e0a build: bump version 0.3.0 → 0.4.0 2021-12-09 15:58:16 +01:00
Jensun Ravichandran
62c5974a85 fix: correct typo in example script 2021-11-17 15:01:38 +01:00
Jensun Ravichandran
1d26226a2f fix(warning): specify dimension explicitly when calling softmin 2021-11-16 10:19:31 +01:00
Christoph
4232d0ed2a fix: spelling issues for previous commits 2021-11-15 11:43:39 +01:00
Christoph
a9edf06507 feat: ImageGTLVQ and SiameseGTLVQ with examples 2021-11-15 11:43:39 +01:00
Christoph
d3bb430104 feat: gtlvq with examples 2021-11-15 11:43:39 +01:00
Alexander Engelsberger
6ffd27d12a chore: Remove PytorchLightning CLI related code
Could be moved in a seperate plugin.
2021-10-11 15:16:12 +02:00
Alexander Engelsberger
859e2cae69 docs(dependencies): Add missing ipykernel dependency for docs 2021-10-11 15:11:53 +02:00
Alexander Engelsberger
d7ea89d47e feat: add simple test step 2021-09-10 19:19:51 +02:00
Jensun Ravichandran
fa928afe2c feat(vis): 2D EV projection for GMLVQ 2021-09-01 10:49:57 +02:00
Alexander Engelsberger
7d4a041df2 build: bump version 0.2.0 → 0.3.0 2021-08-30 20:50:03 +02:00
Alexander Engelsberger
04c51c00c6 ci: seperate build step 2021-08-30 20:44:16 +02:00
Alexander Engelsberger
62185b38cf chore: Update prototorch dependency 2021-08-30 20:32:47 +02:00
Alexander Engelsberger
7b93cd4ad5 feat(compatibility): Python3.6 compatibility 2021-08-30 20:32:40 +02:00
Alexander Engelsberger
d7834e2cc0 fix: All examples should work on CPU and GPU now 2021-08-05 11:20:02 +02:00
Alexander Engelsberger
0af8cf36f8 fix: labels where on cpu in forward pass 2021-08-05 09:14:32 +02:00
Jensun Ravichandran
f8ad1d83eb refactor: clean up abstract classes 2021-07-14 19:17:05 +02:00
Jensun Ravichandran
23a3683860 fix(doc): update outdated 2021-07-12 21:21:29 +02:00
Jensun Ravichandran
4be9fb81eb feat(model): implement MedianLVQ 2021-07-06 17:12:51 +02:00
Jensun Ravichandran
9d38123114 refactor: use GLVQLoss instead of LossLayer 2021-07-06 17:09:21 +02:00
Jensun Ravichandran
0f9f24e36a feat: add early-stopping and pruning to examples/warm_starting.py 2021-06-30 16:04:26 +02:00
Jensun Ravichandran
09e3ef1d0e fix: remove deprecated Trainer.accelerator_backend 2021-06-30 16:03:45 +02:00
Alexander Engelsberger
7b9b767113 fix: training loss is a zero dimensional tensor
Should fix the problem with EarlyStopping callback.
2021-06-25 17:07:06 +02:00
Jensun Ravichandran
f56ec44afe chore(github): update bug report issue template 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
67a20124e8 chore(github): add issue templates 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
72af03b991 refactor: use LinearTransform instead of torch.nn.Linear 2021-06-25 17:07:06 +02:00
64 changed files with 3214 additions and 980 deletions

View File

@@ -1,9 +1,11 @@
[bumpversion]
current_version = 0.2.0
current_version = 1.0.0a3
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
serialize =
{major}.{minor}.{patch}-{release}
{major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py]

View File

@@ -1,15 +0,0 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

View File

@@ -1,2 +0,0 @@
comment:
require_changes: yes

38
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps to reproduce the behavior**
1. ...
2. Run script '...' or this snippet:
```python
import prototorch as pt
...
```
3. See errors
**Expected behavior**
A clear and concise description of what you expected to happen.
**Observed behavior**
A clear and concise description of what actually happened.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**System and version information**
- OS: [e.g. Ubuntu 20.10]
- ProtoTorch Version: [e.g. 0.4.0]
- Python Version: [e.g. 3.9.5]
**Additional context**
Add any other context about the problem here.

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

25
.github/workflows/examples.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
# Thi workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: examples
on:
push:
paths:
- 'examples/**.py'
jobs:
cpu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/

75
.github/workflows/pythonapp.yml vendored Normal file
View File

@@ -0,0 +1,75 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: tests
on:
push:
pull_request:
branches: [ master ]
jobs:
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v2.0.3
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -3,9 +3,10 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.2.0
hooks:
- id: trailing-whitespace
exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
@@ -18,19 +19,19 @@ repos:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.902
rev: v0.950
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0
rev: v0.32.0
hooks:
- id: yapf
@@ -42,7 +43,7 @@ repos:
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
rev: v2.32.1
hooks:
- id: pyupgrade

View File

@@ -1,25 +0,0 @@
dist: bionic
sudo: false
language: python
python: 3.9
cache:
directories:
- "$HOME/.cache/pip"
- "./tests/artifacts"
- "$HOME/datasets"
install:
- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
- pip install .[all] --progress-bar off
script:
- coverage run -m pytest
- ./tests/test_examples.sh examples/
after_success:
- bash <(curl -s https://codecov.io/bash)
deploy:
provider: pypi
username: __token__
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true

View File

@@ -1,6 +1,5 @@
# ProtoTorch Models
[![Build Status](https://api.travis-ci.com/si-cim/prototorch_models.svg?branch=main)](https://travis-ci.com/github/si-cim/prototorch_models)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
@@ -36,6 +35,7 @@ be available for use in your Python environment as `prototorch.models`.
- Soft Learning Vector Quantization (SLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
- Median-LVQ
### Other
@@ -51,7 +51,6 @@ be available for use in your Python environment as `prototorch.models`.
## Planned models
- Median-LVQ
- Generalized Tangent Learning Vector Quantization (GTLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ)

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.2.0"
release = "1.0.0-a3"
# -- General configuration ---------------------------------------------------

File diff suppressed because one or more lines are too long

View File

@@ -1,12 +1,22 @@
"""CBC example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -15,11 +25,8 @@ if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
train_loader = DataLoader(train_ds, batch_size=32)
# Hyperparameters
hparams = dict(
@@ -30,23 +37,30 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.CBC(
model = CBC(
hparams,
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
reasonings_iniitializer=pt.initializers.
components_initializer=pt.initializers.SSCI(train_ds, noise=0.1),
reasonings_initializer=pt.initializers.
PurePositiveReasoningsInitializer(),
)
# Callbacks
vis = pt.models.VisCBC2D(data=train_ds,
vis = VisCBC2D(
data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True)
axis_off=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
callbacks=[
vis,
],
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
)
# Training loop

View File

@@ -1,8 +0,0 @@
# Examples using Lightning CLI
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```

View File

@@ -1,20 +0,0 @@
"""GMLVQ example using the MNIST dataset."""
import torch
from pytorch_lightning.utilities.cli import LightningCLI
import prototorch as pt
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@@ -1,11 +0,0 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@@ -1,12 +1,29 @@
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
CELVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -16,15 +33,17 @@ if __name__ == "__main__":
num_classes = 4
num_features = 2
num_clusters = 1
train_ds = pt.datasets.Random(num_samples=500,
train_ds = pt.datasets.Random(
num_samples=500,
num_classes=num_classes,
num_features=num_features,
num_clusters=num_clusters,
separation=3.0,
seed=42)
seed=42,
)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
train_loader = DataLoader(train_ds, batch_size=256)
# Hyperparameters
prototypes_per_class = num_clusters * 5
@@ -34,7 +53,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.CELVQ(
model = CELVQ(
hparams,
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
)
@@ -43,18 +62,18 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
logging.info(model)
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds)
pruning = pt.models.PruneLoserPrototypes(
vis = VisGLVQ2D(train_ds)
pruning = PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1%
idle_epochs=20, # pruning too early may cause problems
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
frequency=1, # prune every epoch
verbose=True,
)
es = pl.callbacks.EarlyStopping(
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
@@ -71,10 +90,9 @@ if __name__ == "__main__":
pruning,
es,
],
progress_bar_refresh_rate=0,
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
)
# Training loop

View File

@@ -1,13 +1,24 @@
"""GLVQ example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -17,7 +28,7 @@ if __name__ == "__main__":
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)
# Hyperparameters
hparams = dict(
@@ -29,7 +40,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GLVQ(
model = GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -41,15 +52,28 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./glvq_iris.ckpt")
# Load saved model
new_model = GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
logging.info(new_model)

73
examples/gmlvq_iris.py Normal file
View File

@@ -0,0 +1,73 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=4,
latent_dim=4,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 4)
# Callbacks
vis = VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,14 +1,29 @@
"""GMLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
ImageGMLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -33,12 +48,8 @@ if __name__ == "__main__":
)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=256)
test_loader = torch.utils.data.DataLoader(test_ds,
num_workers=0,
batch_size=256)
train_loader = DataLoader(train_ds, num_workers=4, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=4, batch_size=256)
# Hyperparameters
num_classes = 10
@@ -52,14 +63,14 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.ImageGMLVQ(
model = ImageGMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Callbacks
vis = pt.models.VisImgComp(
vis = VisImgComp(
data=train_ds,
num_columns=10,
show=False,
@@ -69,14 +80,14 @@ if __name__ == "__main__":
embedding_data=200,
flatten_data=False,
)
pruning = pt.models.PruneLoserPrototypes(
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = pl.callbacks.EarlyStopping(
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
@@ -90,11 +101,11 @@ if __name__ == "__main__":
callbacks=[
vis,
pruning,
# es,
es,
],
terminate_on_nan=True,
weights_summary=None,
# accelerator="ddp",
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -1,12 +1,28 @@
"""GLVQ example using the spiral dataset."""
"""GMLVQ example using the spiral dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
GMLVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -16,7 +32,7 @@ if __name__ == "__main__":
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
train_loader = DataLoader(train_ds, batch_size=256)
# Hyperparameters
num_classes = 2
@@ -32,19 +48,19 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GMLVQ(
model = GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
)
# Callbacks
vis = pt.models.VisGLVQ2D(
vis = VisGLVQ2D(
train_ds,
show_last_only=False,
block=False,
)
pruning = pt.models.PruneLoserPrototypes(
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=5,
@@ -53,7 +69,7 @@ if __name__ == "__main__":
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
verbose=True,
)
es = pl.callbacks.EarlyStopping(
es = EarlyStopping(
monitor="train_loss",
min_delta=1.0,
patience=5,
@@ -66,10 +82,12 @@ if __name__ == "__main__":
args,
callbacks=[
vis,
# es, # FIXME
es,
pruning,
],
terminate_on_nan=True,
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -1,10 +1,19 @@
"""Growing Neural Gas example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
@@ -13,11 +22,11 @@ if __name__ == "__main__":
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
seed_everything(seed=42)
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
@@ -27,7 +36,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GrowingNeuralGas(
model = GrowingNeuralGas(
hparams,
prototypes_initializer=pt.initializers.ZCI(2),
)
@@ -36,17 +45,20 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Model summary
print(model)
logging.info(model)
# Callbacks
vis = pt.models.VisNG2D(data=train_loader)
vis = VisNG2D(data=train_loader)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=100,
callbacks=[vis],
weights_summary="full",
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

116
examples/gtlvq_mnist.py Normal file
View File

@@ -0,0 +1,116 @@
"""GTLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
ImageGTLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
test_ds = MNIST(
"~/datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=0, batch_size=256)
# Hyperparameters
num_classes = 10
prototypes_per_class = 1
hparams = dict(
input_dim=28 * 28,
latent_dim=28,
distribution=(num_classes, prototypes_per_class),
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = ImageGTLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
#Use one batch of data for subspace initiator.
omega_initializer=pt.initializers.PCALinearTransformInitializer(
next(iter(train_loader))[0].reshape(256, 28 * 28)))
# Callbacks
vis = VisImgComp(
data=train_ds,
num_columns=10,
show=False,
tensorboard=True,
random_data=100,
add_embedding=True,
embedding_data=200,
flatten_data=False,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
# using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

76
examples/gtlvq_moons.py Normal file
View File

@@ -0,0 +1,76 @@
"""Localized-GTLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=256,
shuffle=True,
)
# Hyperparameters
# Latent_dim should be lower than input dim.
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
# Initialize the model
model = GTLVQ(hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,12 +1,19 @@
"""k-NN example using the Iris dataset from scikit-learn."""
import argparse
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import KNN, VisGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Command-line arguments
@@ -15,28 +22,38 @@ if __name__ == "__main__":
args = parser.parse_args()
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
X, y = load_iris(return_X_y=True)
X = X[:, 0:3:2]
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
test_size=0.5,
random_state=42,
)
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
train_loader = DataLoader(train_ds, batch_size=16)
test_loader = DataLoader(test_ds, batch_size=16)
# Hyperparameters
hparams = dict(k=5)
# Initialize the model
model = pt.models.KNN(hparams, data=train_ds)
model = KNN(hparams, data=train_ds)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
logging.info(model)
# Callbacks
vis = pt.models.VisGLVQ2D(
data=(x_train, y_train),
vis = VisGLVQ2D(
data=(X_train, y_train),
resolution=200,
block=True,
)
@@ -45,8 +62,11 @@ if __name__ == "__main__":
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=1,
callbacks=[vis],
weights_summary="full",
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
@@ -54,5 +74,8 @@ if __name__ == "__main__":
trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(x_train))
print(y_pred)
y_pred = model.predict(torch.tensor(X_train))
logging.info(y_pred)
# Test
trainer.test(model, dataloaders=test_loader)

View File

@@ -1,15 +1,25 @@
"""Kohonen Self Organizing Map."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Vis2DColorSOM(pl.Callback):
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
super().__init__()
self.title = title
@@ -17,7 +27,7 @@ class Vis2DColorSOM(pl.Callback):
self.data = data
self.pause_time = pause_time
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module: KohonenSOM):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
@@ -30,12 +40,14 @@ class Vis2DColorSOM(pl.Callback):
d = pl_module.compute_distances(self.data)
wp = pl_module.predict_from_distances(d)
for i, iloc in enumerate(wp):
plt.text(iloc[1],
plt.text(
iloc[1],
iloc[0],
cnames[i],
color_names[i],
ha="center",
va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0))
bbox=dict(facecolor="white", alpha=0.5, lw=0),
)
if trainer.current_epoch != trainer.max_epochs - 1:
plt.pause(self.pause_time)
@@ -50,7 +62,7 @@ if __name__ == "__main__":
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
seed_everything(seed=42)
# Prepare the data
hex_colors = [
@@ -58,15 +70,15 @@ if __name__ == "__main__":
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
]
cnames = [
color_names = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey", "olive", "purple", "orange"
]
colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0
train_ds = torch.utils.data.TensorDataset(data)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
train_ds = TensorDataset(data)
train_loader = DataLoader(train_ds, batch_size=8)
# Hyperparameters
hparams = dict(
@@ -77,7 +89,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.KohonenSOM(
model = KohonenSOM(
hparams,
prototypes_initializer=pt.initializers.RNCI(3),
)
@@ -86,7 +98,7 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 3)
# Model summary
print(model)
logging.info(model)
# Callbacks
vis = Vis2DColorSOM(data=data)
@@ -95,8 +107,11 @@ if __name__ == "__main__":
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=500,
callbacks=[vis],
weights_summary="full",
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -1,10 +1,20 @@
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
@@ -13,15 +23,13 @@ if __name__ == "__main__":
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
# Hyperparameters
hparams = dict(
@@ -31,7 +39,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.LGMLVQ(
model = LGMLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
@@ -40,11 +48,11 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
logging.info(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
@@ -60,8 +68,9 @@ if __name__ == "__main__":
vis,
es,
],
weights_summary="full",
accelerator="ddp",
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
)
# Training loop

View File

@@ -1,13 +1,26 @@
"""LVQMLN example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
LVQMLN,
PruneLoserPrototypes,
VisSiameseGLVQ2D,
)
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
@@ -33,10 +46,10 @@ if __name__ == "__main__":
train_ds = pt.datasets.Iris()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
seed_everything(seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
@@ -49,7 +62,7 @@ if __name__ == "__main__":
backbone = Backbone()
# Initialize the model
model = pt.models.LVQMLN(
model = LVQMLN(
hparams,
prototypes_initializer=pt.initializers.SSCI(
train_ds,
@@ -58,18 +71,15 @@ if __name__ == "__main__":
backbone=backbone,
)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(
vis = VisSiameseGLVQ2D(
data=train_ds,
map_protos=False,
border=0.1,
resolution=500,
axis_off=True,
)
pruning = pt.models.PruneLoserPrototypes(
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
@@ -84,6 +94,9 @@ if __name__ == "__main__":
vis,
pruning,
],
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
)
# Training loop

View File

@@ -0,0 +1,68 @@
"""Median-LVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
)
# Initialize the model
model = MedianLVQ(
hparams=dict(distribution=(3, 2), lr=0.01),
prototypes_initializer=pt.initializers.SSCI(train_ds),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.01,
patience=5,
mode="max",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,15 +1,26 @@
"""Neural Gas example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -17,7 +28,7 @@ if __name__ == "__main__":
# Prepare and pre-process the dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
x_train = x_train[:, 0:3:2]
scaler = StandardScaler()
scaler.fit(x_train)
x_train = scaler.transform(x_train)
@@ -25,7 +36,7 @@ if __name__ == "__main__":
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
@@ -35,7 +46,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.NeuralGas(
model = NeuralGas(
hparams,
prototypes_initializer=pt.core.ZCI(2),
lr_scheduler=ExponentialLR,
@@ -45,17 +56,18 @@ if __name__ == "__main__":
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisNG2D(data=train_ds)
vis = VisNG2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -1,10 +1,18 @@
"""RSLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
@@ -13,13 +21,13 @@ if __name__ == "__main__":
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
seed_everything(seed=42)
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
@@ -33,7 +41,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.RSLVQ(
model = RSLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
@@ -42,19 +50,18 @@ if __name__ == "__main__":
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
callbacks=[
vis,
],
detect_anomaly=True,
max_epochs=100,
log_every_n_steps=1,
)
# Training loop

View File

@@ -1,13 +1,22 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
@@ -33,10 +42,10 @@ if __name__ == "__main__":
train_ds = pt.datasets.Iris()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
seed_everything(seed=2)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
@@ -49,23 +58,25 @@ if __name__ == "__main__":
backbone = Backbone()
# Initialize the model
model = pt.models.SiameseGLVQ(
model = SiameseGLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -0,0 +1,85 @@
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=2)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
input_dim=2,
latent_dim=1,
)
# Initialize the backbone
backbone = Backbone(latent_size=hparams["input_dim"])
# Initialize the model
model = SiameseGTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,13 +1,30 @@
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
GLVQ,
KNN,
GrowingNeuralGas,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
@@ -15,10 +32,10 @@ if __name__ == "__main__":
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)
# Initialize the gng
gng = pt.models.GrowingNeuralGas(
gng = GrowingNeuralGas(
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
prototypes_initializer=pt.initializers.ZCI(2),
lr_scheduler=ExponentialLR,
@@ -26,7 +43,7 @@ if __name__ == "__main__":
)
# Callbacks
es = pl.callbacks.EarlyStopping(
es = EarlyStopping(
monitor="loss",
min_delta=0.001,
patience=20,
@@ -37,9 +54,12 @@ if __name__ == "__main__":
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=200,
callbacks=[es],
weights_summary=None,
max_epochs=1000,
callbacks=[
es,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
@@ -52,12 +72,12 @@ if __name__ == "__main__":
)
# Warm-start prototypes
knn = pt.models.KNN(dict(k=1), data=train_ds)
knn = KNN(dict(k=1), data=train_ds)
prototypes = gng.prototypes
plabels = knn.predict(prototypes)
# Initialize the model
model = pt.models.GLVQ(
model = GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.LCI(prototypes),
@@ -70,14 +90,34 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = VisGLVQ2D(data=train_ds)
pruning = PruneLoserPrototypes(
threshold=0.02,
idle_epochs=2,
prune_quota_per_epoch=5,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=10,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop

View File

@@ -0,0 +1,88 @@
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.y.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris()
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=32,
num_workers=0,
shuffle=True,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
lr=dict(components_layer=0.1, _omega=0),
input_dim=4,
distribution=dict(
num_classes=3,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GMLVQ(hyperparameters)
print(model)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=3,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
mode="max",
patience=10,
)
# Visualization Callback
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
],
max_epochs=1000,
)
# Train
trainer.fit(model, train_loader)

View File

@@ -1,7 +1,5 @@
"""`models` plugin for the `prototorch` package."""
from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
@@ -10,17 +8,32 @@ from .glvq import (
GLVQ21,
GMLVQ,
GRLVQ,
GTLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
ImageGTLVQ,
SiameseGLVQ,
SiameseGMLVQ,
SiameseGTLVQ,
)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .lvq import (
LVQ1,
LVQ21,
MedianLVQ,
)
from .probabilistic import (
CELVQ,
RSLVQ,
SLVQ,
)
from .unsupervised import (
GrowingNeuralGas,
KohonenSOM,
NeuralGas,
)
from .vis import *
__version__ = "0.2.0"
__version__ = "1.0.0-a3"

View File

@@ -1,33 +1,38 @@
"""Abstract classes to be inherited by prototorch models."""
from typing import Final, final
import logging
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchMixin(object):
pass
from prototorch.core.competitions import WTAC
from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.core.pooling import stratified_min_pooling
from prototorch.nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
"""All ProtoTorch models are ProtoTorch Bolts.
hparams:
- lr: learning rate
kwargs:
- optimizer: optimizer class
- lr_scheduler: learning rate scheduler class
- lr_scheduler_kwargs: learning rate scheduler kwargs
"""
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__()
@@ -42,6 +47,43 @@ class PrototypeModel(ProtoTorchBolt):
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
else:
return optimizer
def reconfigure_optimizers(self):
if self.trainer:
self.trainer.strategy.setup_optimizers(self.trainer)
else:
logging.warning("No trainer to reconfigure optimizers!")
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class PrototypeModel(ProtoTorchBolt):
"""Abstract Prototype Model
kwargs:
- distance_fn: distance function
"""
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@@ -58,33 +100,20 @@ class PrototypeModel(ProtoTorchBolt):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
else:
return optimizer
@final
def reconfigure_optimizers(self):
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
proto_layer: Components
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -92,12 +121,12 @@ class UnsupervisedPrototypeModel(PrototypeModel):
prototypes_initializer = kwargs.get("prototypes_initializer", None)
if prototypes_initializer is not None:
self.proto_layer = Components(
self.hparams.num_prototypes,
self.hparams["num_prototypes"],
initializer=prototypes_initializer,
)
def compute_distances(self, x):
protos = self.proto_layer()
protos = self.proto_layer().type_as(x)
distances = self.distance_layer(x, protos)
return distances
@@ -107,19 +136,34 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
proto_layer: LabeledComponents
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
distribution=distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
proto_shape = self.proto_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
self.competition_layer = WTAC()
@property
@@ -137,14 +181,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x):
distances = self.compute_distances(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning)
y_pred = F.softmin(winning, dim=1)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels)
return y_pred
@@ -166,27 +210,10 @@ class SupervisedPrototypeModel(PrototypeModel):
prog_bar=True,
logger=True)
def test_step(self, batch, batch_idx):
x, targets = batch
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization: Final = False
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
@final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()
self.log("test_acc", accuracy)

View File

@@ -1,24 +1,30 @@
"""Lightning Callbacks."""
import logging
from typing import TYPE_CHECKING
import pytorch_lightning as pl
import torch
from prototorch.core.initializers import LiteralCompInitializer
from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology
if TYPE_CHECKING:
from prototorch.models import GLVQ, GrowingNeuralGas
class PruneLoserPrototypes(pl.Callback):
def __init__(self,
def __init__(
self,
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=-1,
frequency=1,
replace=False,
prototypes_initializer=None,
verbose=False):
verbose=False,
):
self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning
self.prune_quota_per_epoch = prune_quota_per_epoch
@@ -27,56 +33,59 @@ class PruneLoserPrototypes(pl.Callback):
self.verbose = verbose
self.prototypes_initializer = prototypes_initializer
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if (trainer.current_epoch + 1) % self.frequency:
return None
ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune.tolist()
to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune_tensor.tolist()
prune_labels = pl_module.prototype_labels[to_prune]
if self.prune_quota_per_epoch > 0:
to_prune = to_prune[:self.prune_quota_per_epoch]
prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0:
if self.verbose:
print(f"\nPrototype win ratios: {ratios}")
print(f"Pruning prototypes at: {to_prune}")
print(f"Corresponding labels are: {prune_labels.tolist()}")
logging.debug(f"\nPrototype win ratios: {ratios}")
logging.debug(f"Pruning prototypes at: {to_prune}")
logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
if self.replace:
labels, counts = torch.unique(prune_labels,
sorted=True,
return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist()))
if self.verbose:
print(f"Re-adding pruned prototypes...")
print(f"{distribution=}")
logging.info(f"Re-adding pruned prototypes...")
logging.debug(f"distribution={distribution}")
pl_module.add_prototypes(
distribution=distribution,
components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes
if self.verbose:
print(f"`num_prototypes` changed from {cur_num_protos} "
logging.info(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.")
return True
class PrototypeConvergence(pl.Callback):
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
self.min_delta = min_delta
self.idle_epochs = idle_epochs # epochs to wait
self.verbose = verbose
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if self.verbose:
print("Stopping...")
logging.info("Stopping...")
# TODO
return True
@@ -89,16 +98,21 @@ class GNGCallback(pl.Callback):
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
def on_train_epoch_end(
self,
trainer: pl.Trainer,
pl_module: "GrowingNeuralGas",
):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components: Components = pl_module.proto_layer.components
components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
@@ -118,8 +132,9 @@ class GNGCallback(pl.Callback):
# Add component
pl_module.proto_layer.add_components(
None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
1,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
)
# Adjust Topology
topology.add_prototype()
@@ -134,4 +149,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
trainer.strategy.setup_optimizers(trainer)

View File

@@ -1,27 +1,30 @@
import torch
import torch.nn.functional as F
import torchmetrics
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import RandomReasoningsInitializer
from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import CBCC
from ..core.components import ReasoningComponents
from ..core.initializers import RandomReasoningsInitializer
from ..core.losses import MarginLoss
from ..core.similarities import euclidean_similarity
from ..nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ
from .mixins import ImagePrototypesMixin
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
proto_layer: ReasoningComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__(hparams, skip_proto_layer=True, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None)
reasonings_initializer = kwargs.get("reasonings_initializer",
RandomReasoningsInitializer())
self.components_layer = ReasoningComponents(
self.hparams.distribution,
self.hparams["distribution"],
components_initializer=components_initializer,
reasonings_initializer=reasonings_initializer,
)
@@ -31,7 +34,7 @@ class CBC(SiameseGLVQ):
# Namespace hook
self.proto_layer = self.components_layer
self.loss = MarginLoss(self.hparams.margin)
self.loss = MarginLoss(self.hparams["margin"])
def forward(self, x):
components, reasonings = self.components_layer()
@@ -47,8 +50,8 @@ class CBC(SiameseGLVQ):
x, y = batch
y_pred = self(x)
num_classes = self.num_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean(dim=0)
y_true = F.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean()
return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):

View File

@@ -1,124 +0,0 @@
"""Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import prototorch as pt
# MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
# Download mnist dataset as side-effect, only called on the first cpu
def prepare_data(self):
MNIST("~/datasets", train=True, download=True)
MNIST("~/datasets", train=False, download=True)
# called for every GPU/machine (assigning state is OK)
def setup(self, stage=None):
# Transforms
transform = transforms.Compose([
transforms.ToTensor(),
])
# Split dataset
if stage in (None, "fit"):
mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(
mnist_train,
[55000, 5000],
)
if stage == (None, "test"):
self.mnist_test = MNIST(
"~/datasets",
train=False,
transform=transform,
)
# Dataloaders
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
return mnist_test
# def train_on_mnist(batch_size=256) -> type:
# class DataClass(pl.LightningModule):
# datamodule = MNISTDataModule(batch_size=batch_size)
# def __init__(self, *args, **kwargs):
# prototype_initializer = kwargs.pop(
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)

View File

@@ -5,8 +5,7 @@ Modules not yet available in prototorch go here temporarily.
"""
import torch
from ..core.similarities import gaussian
from prototorch.core.similarities import gaussian
def rank_scaled_gaussian(distances, lambd):
@@ -15,7 +14,46 @@ def rank_scaled_gaussian(distances, lambd):
return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module):
def __init__(self, variance):
super().__init__()
self.variance = variance
@@ -25,6 +63,7 @@ class GaussianPrior(torch.nn.Module):
class RankScaledGaussianPrior(torch.nn.Module):
def __init__(self, lambd):
super().__init__()
self.lambd = lambd
@@ -34,6 +73,7 @@ class RankScaledGaussianPrior(torch.nn.Module):
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit

View File

@@ -1,49 +1,70 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.core.initializers import EyeLinearTransformInitializer
from prototorch.core.losses import (
GLVQLoss,
lvq1_loss,
lvq21_loss,
)
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .abstract import SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
from .mixins import ImagePrototypesMixin
class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("margin", 0.0)
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
# Layers
transfer_fn = get_activation(self.hparams.transfer_fn)
self.transfer_layer = LambdaLayer(transfer_fn)
# Loss
self.loss = LossLayer(glvq_loss)
self.loss = GLVQLoss(
margin=self.hparams["margin"],
transfer_fn=self.hparams["transfer_fn"],
beta=self.hparams["transfer_beta"],
)
# def on_save_checkpoint(self, checkpoint):
# if "prototype_win_ratios" in checkpoint["state_dict"]:
# del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
torch.zeros(self.num_prototypes, device=self.device),
)
def on_epoch_start(self):
def on_train_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances):
batch_size = len(distances)
prototype_wc = torch.zeros(self.num_prototypes,
prototype_wc = torch.zeros(
self.num_prototypes,
dtype=torch.long,
device=self.device)
wi, wc = torch.unique(distances.min(dim=-1).indices,
device=self.device,
)
wi, wc = torch.unique(
distances.min(dim=-1).indices,
sorted=True,
return_counts=True)
return_counts=True,
)
prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([
@@ -54,10 +75,8 @@ class GLVQ(SupervisedPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
plabels = self.proto_layer.labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
_, plabels = self.proto_layer()
loss = self.loss(out, y, plabels)
return out, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):
@@ -68,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel):
return train_loss
def validation_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc")
return val_loss
def test_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
return test_loss
@@ -86,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
# TODO
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
@@ -99,22 +112,28 @@ class SiameseGLVQ(GLVQ):
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
def __init__(
self,
hparams,
backbone=torch.nn.Identity(),
both_path_gradients=False,
**kwargs):
**kwargs,
):
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
self.backbone = backbone
self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
proto_opt = self.optimizer(
self.proto_layer.parameters(),
lr=self.hparams["proto_lr"],
)
# Only add a backbone optimizer if backbone has trainable parameters
if (bb_params := list(self.backbone.parameters())):
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
@@ -132,9 +151,13 @@ class SiameseGLVQ(GLVQ):
protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(True)
self.backbone.requires_grad_(bb_grad)
distances = self.distance_layer(latent_x, latent_protos)
return distances
@@ -164,6 +187,7 @@ class LVQMLN(SiameseGLVQ):
rather in the embedding space.
"""
def compute_distances(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
@@ -179,11 +203,13 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
"""
_relevances: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Additional parameters
relevances = torch.ones(self.hparams.input_dim, device=self.device)
relevances = torch.ones(self.hparams["input_dim"], device=self.device)
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
@@ -204,22 +230,27 @@ class SiameseGMLVQ(SiameseGLVQ):
Implemented as a Siamese network with a linear transformation backbone.
"""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Override the backbone
self.backbone = torch.nn.Linear(self.hparams.input_dim,
self.hparams.latent_dim,
bias=False)
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
self.backbone = LinearTransform(
self.hparams["input_dim"],
self.hparams["latent_dim"],
initializer=omega_initializer,
)
@property
def omega_matrix(self):
return self.backbone.weight.detach().cpu()
return self.backbone.weights
@property
def lambda_matrix(self):
omega = self.backbone.weight # (latent_dim, input_dim)
lam = omega.T @ omega
omega = self.backbone.weights # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
@@ -230,23 +261,39 @@ class GMLVQ(GLVQ):
function. This makes it easier to implement a localized variant.
"""
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
omega_initializer = kwargs.get(
"omega_initializer",
EyeLinearTransformInitializer(),
)
omega = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
self.backbone = LambdaLayer(
lambda x: x @ self._omega,
name="omega matrix",
)
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)
@@ -258,6 +305,7 @@ class GMLVQ(GLVQ):
class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", lomega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@@ -265,15 +313,59 @@ class LGMLVQ(GMLVQ):
# Re-register `_omega` to override the one from the super class.
omega = torch.randn(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Tangent Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None:
subspace = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
omega = torch.repeat_interleave(
subspace.unsqueeze(0),
self.num_prototypes,
dim=0,
)
else:
omega = torch.rand(
self.num_prototypes,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
# Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
"""Generalized Tangent Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
"""
class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq1_loss)
@@ -282,6 +374,7 @@ class GLVQ1(GLVQ):
class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq21_loss)
@@ -304,3 +397,18 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
after updates.
"""
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
"""GTLVQ for training on image data.
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))

View File

@@ -2,17 +2,22 @@
import warnings
from ..core.competitions import KNNC
from ..core.components import LabeledComponents
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from ..utils.utils import parse_data_arg
from prototorch.core.competitions import KNNC
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
LiteralCompInitializer,
LiteralLabelsInitializer,
)
from prototorch.utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__(hparams, skip_proto_layer=True, **kwargs)
# Default hparams
self.hparams.setdefault("k", 1)
@@ -24,7 +29,7 @@ class KNN(SupervisedPrototypeModel):
# Layers
self.proto_layer = LabeledComponents(
distribution=[],
distribution=len(data) * [1],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k)
@@ -32,10 +37,7 @@ class KNN(SupervisedPrototypeModel):
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
def on_train_batch_start(self, train_batch, batch_idx):
warnings.warn("k-NN has no training, skipping!")
return -1

View File

@@ -1,16 +1,21 @@
"""LVQ models that are optimized using non-gradient methods."""
from ..core.losses import _get_dp_dm
from .abstract import NonGradientMixin
import logging
from collections import OrderedDict
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
from .glvq import GLVQ
from .mixins import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
@@ -24,12 +29,14 @@ class LVQ1(NonGradientMixin, GLVQ):
else:
shift = protos[w] - xi
updated_protos = protos + 0.0
updated_protos[w] = protos[w] + (self.hparams.lr * shift)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
updated_protos[w] = protos[w] + (self.hparams["lr"] * shift)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
print(f"{dis=}")
print(f"{y=}")
logging.debug(f"dis={dis}")
logging.debug(f"y={y}")
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -38,9 +45,9 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
@@ -54,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ):
shiftp = xi - protos[wp]
shiftn = protos[wn] - xi
updated_protos = protos + 0.0
updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -66,4 +75,64 @@ class LVQ21(NonGradientMixin, GLVQ):
class MedianLVQ(NonGradientMixin, GLVQ):
"""Median LVQ"""
"""Median LVQ
# TODO Avoid computing distances over and over
"""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
get_activation(self.hparams["transfer_fn"]))
def _f(self, x, y, protos, plabels):
d = self.distance_layer(x, protos)
dp, dm = _get_dp_dm(d, y, plabels, with_indices=False)
mu = (dp - dm) / (dp + dm)
negative_mu = -1.0 * mu
f = self.transfer_layer(
negative_mu,
beta=self.hparams["transfer_beta"],
) + 1.0
return f
def expectation(self, x, y, protos, plabels):
f = self._f(x, y, protos, plabels)
gamma = f / f.sum()
return gamma
def lower_bound(self, x, y, protos, plabels, gamma):
f = self._f(x, y, protos, plabels)
lower_bound = (gamma * f.log()).sum()
return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
for i, _ in enumerate(protos):
# Expectation step
gamma = self.expectation(x, y, protos, plabels)
lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
# Maximization step
_protos = protos + 0
for k, xk in enumerate(x):
_protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
logging.debug(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict(
OrderedDict(_components=_protos),
strict=False,
)
break
# Logging
self.log_acc(dis, y, tag="train_acc")
return None

View File

@@ -0,0 +1,35 @@
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components
class ProtoTorchMixin(pl.LightningModule):
"""All mixins are ProtoTorchMixins."""
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@@ -1,16 +1,20 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.core.losses import nllr_loss, rslvq_loss
from prototorch.core.pooling import (
stratified_min_pooling,
stratified_sum_pooling,
)
from prototorch.nn.wrappers import LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -20,29 +24,37 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
loss = batch_loss.sum(dim=0)
loss = batch_loss.sum()
return out, loss
class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = None
self.rejection_confidence = rejection_confidence
self._conditional_distribution = None
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
y_pred = stratified_sum_pooling(posterior, plabels)
if isinstance(plabels, torch.LongTensor) or isinstance(
plabels, torch.cuda.LongTensor): # type: ignore
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
else:
raise ValueError("Labels must be LongTensor.")
return y_pred
def predict(self, x):
@@ -54,26 +66,44 @@ class ProbabilisticLVQ(GLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum(dim=0)
loss = batch_loss.sum()
return loss
def conditional_distribution(self, distances):
"""Conditional distribution of distances."""
if self._conditional_distribution is None:
raise ValueError("Conditional distribution is not set.")
return self._conditional_distribution(distances)
class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ):
"""Robust Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
@@ -81,10 +111,15 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
TODO: Use Backbone LVQ instead
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd)
# Default hparams
self.hparams.setdefault("lambda", 1.0)
lam = self.hparams.get("lambda", 1.0)
self.conditional_distribution = RankScaledGaussianPrior(lam)
self.loss = torch.nn.KLDivLoss()
# FIXME
@@ -92,5 +127,5 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
# x, y = batch
# y_pred = self(x)
# batch_loss = self.loss(y_pred, y)
# loss = batch_loss.sum(dim=0)
# loss = batch_loss.sum()
# return loss

View File

@@ -2,14 +2,14 @@
import numpy as np
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
from .mixins import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
@@ -18,6 +18,8 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
TODO Allow non-2D grids
"""
_grid: torch.Tensor
def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape")
# Ignore `num_prototypes`
@@ -34,7 +36,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
# Additional parameters
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma
self._lr = self.hparams.lr
@@ -53,12 +55,14 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer.components
protos = self.proto_layer()
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
self.proto_layer.load_state_dict(
{"_components": updated_protos},
strict=False,
)
def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
@@ -69,6 +73,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
class HeskesSOM(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -78,6 +83,7 @@ class HeskesSOM(UnsupervisedPrototypeModel):
class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -85,13 +91,13 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
num_prototypes=self.hparams.num_prototypes,
agelimit=self.hparams["age_limit"],
num_prototypes=self.hparams["num_prototypes"],
)
def training_step(self, train_batch, batch_idx):
@@ -104,12 +110,10 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.log("loss", loss)
return loss
# def training_epoch_end(self, training_step_outputs):
# print(f"{self.trainer.lr_schedulers}")
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
class GrowingNeuralGas(NeuralGas):
errors: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -118,7 +122,10 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
errors = torch.zeros(
self.hparams["num_prototypes"],
device=self.device,
)
self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx):
@@ -132,8 +139,8 @@ class GrowingNeuralGas(NeuralGas):
mask[torch.arange(len(mask)), winner] = 1.0
dp = d * mask
self.errors += torch.sum(dp * dp, dim=0)
self.errors *= self.hparams.step_reduction
self.errors += torch.sum(dp * dp)
self.errors *= self.hparams["step_reduction"]
self.topology_layer(d)
self.log("loss", loss)
@@ -141,6 +148,8 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self):
return [
GNGCallback(reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq)
GNGCallback(
reduction=self.hparams["insert_reduction"],
freq=self.hparams["insert_freq"],
)
]

View File

@@ -1,20 +1,29 @@
"""Visualization Callbacks."""
import os
import warnings
from typing import Sized
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback):
def __init__(self,
data,
data=None,
title="Prototype Visualization",
cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
legend_labels=None,
border=0.1,
resolution=100,
flatten_data=True,
@@ -24,12 +33,22 @@ class Vis2DAbstract(pl.Callback):
tensorboard=False,
show_last_only=False,
pause_time=0.1,
save=False,
save_dir="./img",
fig_size=(5, 4),
dpi=500,
block=False):
super().__init__()
if data:
if isinstance(data, Dataset):
if isinstance(data, Sized):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
else:
# TODO: Add support for non-sized datasets
raise NotImplementedError(
"Data must be a dataset with a __len__ method.")
elif isinstance(data, DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
@@ -43,8 +62,14 @@ class Vis2DAbstract(pl.Callback):
self.x_train = x
self.y_train = y
else:
self.x_train = None
self.y_train = None
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
@@ -55,22 +80,28 @@ class Vis2DAbstract(pl.Callback):
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.pause_time = pause_time
self.save = save
self.save_dir = save_dir
self.fig_size = fig_size
self.dpi = dpi
self.block = block
if save:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
def precheck(self, trainer):
if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return False
return True
def setup_ax(self, xlabel=None, ylabel=None):
def setup_ax(self):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
if xlabel:
ax.set_xlabel("Data dimension 1")
if ylabel:
ax.set_ylabel("Data dimension 2")
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)
if self.axis_off:
ax.axis("off")
return ax
@@ -107,48 +138,58 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module):
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if self.save:
plt.tight_layout()
self.fig.set_size_inches(*self.fig_size, forward=False)
plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png",
dpi=self.dpi)
if self.show:
if not self.block:
plt.pause(self.pause_time)
else:
plt.show(block=self.block)
def on_train_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
self.log_and_display(trainer, pl_module)
def on_train_end(self, trainer, pl_module):
plt.close()
def visualize(self, pl_module):
raise NotImplementedError
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
@@ -175,18 +216,42 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
@@ -198,20 +263,15 @@ class VisCBC2D(Vis2DAbstract):
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
@@ -225,10 +285,27 @@ class VisNG2D(Vis2DAbstract):
"k-",
)
self.log_and_display(trainer, pl_module)
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
class VisImgComp(Vis2DAbstract):
def __init__(self,
*args,
random_data=0,
@@ -244,32 +321,45 @@ class VisImgComp(Vis2DAbstract):
self.add_embedding = add_embedding
self.embedding_data = embedding_data
def on_train_start(self, trainer, pl_module):
def on_train_start(self, _, pl_module):
if isinstance(pl_module.logger, TensorBoardLogger):
tb = pl_module.logger.experiment
# Add embedding
if self.add_embedding:
if self.x_train is not None and self.y_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
# print(f"{data.shape=}")
# print(f"{self.y_train[ind].shape=}")
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
else:
raise ValueError("No data for add embedding flag")
# Random Data
if self.random_data:
if self.x_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
grid = torchvision.utils.make_grid(data,
nrow=self.num_columns)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
else:
raise ValueError("No data for random data flag")
else:
warnings.warn(
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
@@ -283,14 +373,9 @@ class VisImgComp(Vis2DAbstract):
dataformats=self.dataformats,
)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components,
nrow=self.num_columns)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)

23
prototorch/y/__init__.py Normal file
View File

@@ -0,0 +1,23 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import (
OmegaComparisonMixin,
SimpleComparisonMixin,
)
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import (
MultipleLearningRateMixin,
SingleLearningRateMixin,
)
__all__ = [
'BaseYArchitecture',
"OmegaComparisonMixin",
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"MultipleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",
]

View File

@@ -0,0 +1,212 @@
"""
Proto Y Architecture
Network architecture for Component based Learning.
"""
from dataclasses import dataclass
from typing import (
Callable,
Dict,
Set,
Type,
)
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
...
# Fields
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competition(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
def init_latent(self, hparams: HyperParameters) -> None:
...
def init_comparison(self, hparams: HyperParameters) -> None:
...
def init_competition(self, hparams: HyperParameters) -> None:
...
def init_loss(self, hparams: HyperParameters) -> None:
...
def init_inference(self, hparams: HyperParameters) -> None:
...
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the component set of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparison_measures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparison_measures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
# Y Architecture Hooks
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: Type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch):
# Prediction Metrics
preds = self(batch)
x, y = batch
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
instance(y, preds)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for callback in self.registered_metric_callbacks[metric]:
callback(value, self)
instance.reset()
# Lightning Hooks
# Steps
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step([torch.clone(el) for el in batch])
return self.loss_forward(batch)
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()

View File

@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable, Dict
import torch
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.nn.wrappers import LambdaLayer
from prototorch.y.architectures.base import BaseYArchitecture
from torch import Tensor
from torch.nn.parameter import Parameter
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
**self.comparison_kwargs,
)
return distances
class OmegaComparisonMixin(SimpleComparisonMixin):
"""
Omega Comparison
A comparison layer that uses the positions of the components and the batch for dissimilarity computation.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(SimpleComparisonMixin.HyperParameters):
"""
input_dim: Necessary Field: The dimensionality of the input.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters) -> None:
super().init_comparison(hparams)
# Initialize the omega matrix
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach()
lam = omega @ omega.T
return lam.detach().cpu()

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.y.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
"""
Winner Take All Competition
A competition layer that uses the winner-take-all strategy.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
No hyperparameters.
"""
# Steps
# ----------------------------------------------------------------------------------------------------
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def inference(self, comparison_measures, components):
comp_labels = components[1]
return self.competition_layer(comparison_measures, comp_labels)

View File

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
)
from prototorch.y import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
"""
Supervised Architecture
An architecture that uses labeled Components as component Layer.
"""
components_layer: LabeledComponents
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters:
"""
distribution: A valid prototype distribution. No default possible.
components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
"""
distribution: "dict[str, int]"
component_initializer: AbstractComponentsInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def prototypes(self):
"""
Returns the position of the prototypes.
"""
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
"""
Returns the labels of the prototypes.
"""
return self.components_layer.labels.detach().cpu()

View File

@@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.y.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
"""
GLVQ Loss
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
margin: The margin of the GLVQ loss. Default: 0.0.
transfer_fn: Transfer function to use. Default: sigmoid_beta.
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
"""
margin: float = 0.0
transfer_fn: str = "sigmoid_beta"
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
# Steps
# ----------------------------------------------------------------------------------------------------
def init_loss(self, hparams: HyperParameters):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
**hparams.transfer_args,
)
def loss(self, comparison_measures, batch, components):
target = batch[1]
comp_labels = components[1]
loss = self.loss_layer(comparison_measures, target, comp_labels)
self.log('loss', loss)
return loss

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import Type
import torch
from prototorch.y import BaseYArchitecture
from torch.nn.parameter import Parameter
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
"""
Multiple Learning Rates
Define Different Learning Rates for different parameters.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
return optimizers

218
prototorch/y/callbacks.py Normal file
View File

@@ -0,0 +1,218 @@
import warnings
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
from prototorch.y.architectures.base import BaseYArchitecture
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [
'PiYG',
'PRGn',
'BrBG',
'PuOr',
'RdGy',
'RdBu',
'RdYlBu',
'RdYlGn',
'Spectral',
'coolwarm',
'bwr',
'seismic',
]
class LogTorchmetricCallback(pl.Callback):
def __init__(
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
def setup(
self,
trainer: pl.Trainer,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is no valid metric hook")
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
class LogConfusionMatrix(LogTorchmetricCallback):
def __init__(
self,
num_classes,
name="confusion",
on='prediction',
**kwargs,
):
super().__init__(
name,
torchmetrics.ConfusionMatrix,
on=on,
num_classes=num_classes,
**kwargs,
)
def __call__(self, value, pl_module: BaseYArchitecture):
fig, ax = plt.subplots()
ax.imshow(value.detach().cpu().numpy())
# Show all ticks and label them with the respective list entries
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(
ax.get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
# Loop over data dimensions and create text annotations.
for i in range(len(value)):
for j in range(len(value)):
text = ax.text(
j,
i,
value[i, j].item(),
ha="center",
va="center",
color="w",
)
ax.set_title(self.name)
fig.tight_layout()
pl_module.logger.experiment.add_figure(
tag=self.name,
figure=fig,
close=True,
global_step=pl_module.global_step,
)
class VisGLVQ2D(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(
np.vstack([x_train, protos]),
self.border,
self.resolution,
)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.components_layer.components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class PlotLambdaMatrixToTensorboard(pl.Callback):
def __init__(self, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
warnings.warn(
f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
)
def on_train_start(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def plot_lambda(self, trainer, pl_module: GMLVQ):
self.fig, self.ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = pl_module.lambda_matrix
# normalize lambda matrix
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
# plot lambda matrix
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
self.fig.colorbar(self.ax.images[-1])
# add title
self.ax.set_title('Lambda Matrix')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"lambda_matrix",
self.fig,
trainer.global_step,
)
else:
warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
)

View File

@@ -0,0 +1,5 @@
from .glvq import GLVQ
__all__ = [
"GLVQ",
]

View File

@@ -0,0 +1,35 @@
from dataclasses import dataclass
from prototorch.y import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.y.architectures.loss import GLVQLossMixin
class GLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
SingleLearningRateMixin,
):
"""
Generalized Learning Vector Quantization (GLVQ)
A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
SingleLearningRateMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
No hyperparameters.
"""

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.y import (
GLVQLossMixin,
MultipleLearningRateMixin,
OmegaComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
class GMLVQ(
SupervisedArchitecture,
OmegaComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
MultipleLearningRateMixin,
):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
MultipleLearningRateMixin.HyperParameters,
OmegaComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(
components_layer=0.1,
_omega=0.5,
))

View File

@@ -1,8 +1,23 @@
[isort]
profile = hug
src_paths = isort, test
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79

View File

@@ -22,9 +22,10 @@ with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.6.0",
"pytorch_lightning>=1.3.5",
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics",
"protobuf<3.20.0",
]
CLI = [
"jsonargparse",
@@ -37,6 +38,7 @@ DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
@@ -53,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.2.0",
version="1.0.0-a3",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
@@ -63,7 +65,7 @@ setup(
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.9",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
@@ -79,7 +81,11 @@ setup(
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",

View File

@@ -1,14 +0,0 @@
"""prototorch.models test suite."""
import unittest
class TestDummy(unittest.TestCase):
def setUp(self):
pass
def test_dummy(self):
pass
def tearDown(self):
pass

View File

@@ -1,11 +1,27 @@
#! /bin/bash
# Read Flags
gpu=0
while [ -n "$1" ]; do
case "$1" in
--gpu) gpu=1;;
-g) gpu=1;;
*) path=$1;;
esac
shift
done
python --version
echo "Using GPU: " $gpu
# Loop
failed=0
for example in $(find $1 -maxdepth 1 -name "*.py")
for example in $(find $path -maxdepth 1 -name "*.py")
do
echo -n "$x" $example '... '
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
if [[ $? -ne 0 ]]; then
echo "FAILED!!"
cat run_log.txt

195
tests/test_models.py Normal file
View File

@@ -0,0 +1,195 @@
"""prototorch.models test suite."""
import prototorch as pt
import pytest
import torch
def test_glvq_model_build():
model = pt.models.GLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq1_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq21_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gmlvq_model_build():
model = pt.models.GMLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_grlvq_model_build():
model = pt.models.GRLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gtlvq_model_build():
model = pt.models.GTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lgmlvq_model_build():
model = pt.models.LGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_image_glvq_model_build():
model = pt.models.ImageGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gtlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build():
model = pt.models.LVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lvq21_model_build():
model = pt.models.LVQ21(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_median_lvq_model_build():
model = pt.models.MedianLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_celvq_model_build():
model = pt.models.CELVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_rslvq_model_build():
model = pt.models.RSLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_slvq_model_build():
model = pt.models.SLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_kohonen_som_model_build():
model = pt.models.KohonenSOM(
{"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_neural_gas_model_build():
model = pt.models.NeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)