Compare commits

..

25 Commits

Author SHA1 Message Date
Alexander Engelsberger
d4bf6dbbe9
build: bump version 0.7.0 → 0.7.1 2023-10-25 15:56:53 +02:00
Alexander Engelsberger
c99fdb436c
ci: update action to pyproject toml workflow 2023-10-25 15:56:19 +02:00
Alexander Engelsberger
28ac5f5ed9
build: bump version 0.6.0 → 0.7.0 2023-10-25 15:19:04 +02:00
Alexander Engelsberger
b7f510a9fe
chore: update bumpversion config 2023-10-25 15:18:45 +02:00
Alexander Engelsberger
781ef93b06
ci: remove Python 3.12 2023-10-25 15:09:14 +02:00
Alexander Engelsberger
072e61b3cd
ci: Add Python 3.12 2023-10-25 15:04:05 +02:00
Alexander Engelsberger
71167a8f77
chore: remove optimizer_idx from all steps 2023-10-25 15:03:13 +02:00
Alexander Engelsberger
60990f42d2
fix: update import in tests 2023-06-20 21:18:28 +02:00
Alexander Engelsberger
1e83c439f7
ci: Trigger example test 2023-06-20 19:29:59 +02:00
Alexander Engelsberger
cbbbbeda98
fix: setuptools configuration 2023-06-20 19:25:35 +02:00
Alexander Engelsberger
1b5093627e
build: bump version 0.5.4 → 0.6.0 2023-06-20 18:50:03 +02:00
Alexander Engelsberger
497da90f9c
chore: small changes to configuration 2023-06-20 18:49:57 +02:00
Alexander Engelsberger
2a665e220f
fix: use multiclass accuracy by default 2023-06-20 18:30:18 +02:00
Alexander Engelsberger
4cd6aee330
chore: replace config by pyproject.toml 2023-06-20 18:30:05 +02:00
Alexander Engelsberger
634ef86a2c
fix: example test fixed 2023-06-20 17:42:36 +02:00
Alexander Engelsberger
72e9587a10
fix: remove removed CLI syntax from examples 2023-06-20 17:30:21 +02:00
Alexander Engelsberger
f5e1edf31f
ci: upgrade workflows 2023-06-20 16:39:13 +02:00
Alexander Engelsberger
5e5675d12e
ci: upgrade pre-commit config 2023-06-20 16:37:11 +02:00
Alexander Engelsberger
16f410e809
fix: style fixes 2023-03-09 15:59:49 +01:00
Alexander Engelsberger
46dfb82371
Fix: saving GMLVQ and GRLVQ fixed 2023-03-09 15:50:13 +01:00
Alexander Engelsberger
87fa3f0729
build: bump version 0.5.3 → 0.5.4 2023-03-02 17:29:54 +00:00
Alexander Engelsberger
08db94d507
fix: fix entrypoint configuration 2023-03-02 17:29:23 +00:00
Alexander Engelsberger
8ecf9948b2
build: bump version 0.5.2 → 0.5.3 2023-03-02 17:24:11 +00:00
Alexander Engelsberger
c5f0b86114
chore: upgrade pre commit 2023-03-02 17:23:41 +00:00
Alexander Engelsberger
7506614ada
fix: Update dependency versions 2023-03-02 17:05:39 +00:00
53 changed files with 592 additions and 1536 deletions

View File

@ -1,15 +1,13 @@
[bumpversion] [bumpversion]
current_version = 1.0.0a4 current_version = 0.7.1
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = serialize = {major}.{minor}.{patch}
{major}.{minor}.{patch}-{release}
{major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version} message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py] [bumpversion:file:pyproject.toml]
[bumpversion:file:./prototorch/models/__init__.py] [bumpversion:file:./src/prototorch/models/__init__.py]
[bumpversion:file:./docs/source/conf.py] [bumpversion:file:./docs/source/conf.py]

View File

@ -6,16 +6,16 @@ name: examples
on: on:
push: push:
paths: paths:
- 'examples/**.py' - "examples/**.py"
jobs: jobs:
cpu: cpu:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View File

@ -12,36 +12,36 @@ jobs:
style: style:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
- uses: pre-commit/action@v2.0.3 - uses: pre-commit/action@v3.0.0
compatibility: compatibility:
needs: style needs: style
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest] os: [ubuntu-latest, windows-latest]
exclude: exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest - os: windows-latest
python-version: "3.8" python-version: "3.8"
- os: windows-latest - os: windows-latest
python-version: "3.9" python-version: "3.9"
- os: windows-latest
python-version: "3.10"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
@ -56,18 +56,18 @@ jobs:
needs: compatibility needs: compatibility
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
pip install wheel pip install build
- name: Build package - name: Build package
run: python setup.py sdist bdist_wheel run: python -m build . -C verbose
- name: Publish a Python distribution to PyPI - name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:

View File

@ -3,10 +3,9 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0 rev: v4.4.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
- id: end-of-file-fixer - id: end-of-file-fixer
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
@ -14,17 +13,17 @@ repos:
- id: check-case-conflict - id: check-case-conflict
- repo: https://github.com/myint/autoflake - repo: https://github.com/myint/autoflake
rev: v1.4 rev: v2.1.1
hooks: hooks:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.10.1 rev: 5.12.0
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.950 rev: v1.3.0
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
@ -34,16 +33,17 @@ repos:
rev: v0.32.0 rev: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
additional_dependencies: ["toml"]
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0 rev: v1.10.0
hooks: hooks:
- id: python-use-type-annotations - id: python-use-type-annotations
- id: python-no-log-warn - id: python-no-log-warn
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.32.1 rev: v3.7.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade

View File

@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "1.0.0-a4" release = "0.7.1"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -1,12 +1,11 @@
"""CBC example using the Iris dataset.""" """CBC example using the Iris dataset."""
import argparse import argparse
import warnings import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import CBC, VisCBC2D from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +18,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -53,8 +53,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -7,13 +7,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
CELVQ, CELVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -26,7 +26,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -83,8 +84,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GLVQ, VisGLVQ2D from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -21,7 +21,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -55,8 +56,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GMLVQ, VisGMLVQ2D from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -22,7 +22,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -59,8 +60,10 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds) vis = VisGMLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],
@ -71,3 +74,5 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGMLVQ, ImageGMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@ -26,7 +26,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -96,8 +97,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GMLVQ, GMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -25,7 +25,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -78,8 +79,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GrowingNeuralGas, VisNG2D from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -51,8 +52,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_loader) vis = VisNG2D(data=train_loader)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

77
examples/grlvq_iris.py Normal file
View File

@ -0,0 +1,77 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris([0, 1])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=2,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GRLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=5,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGTLVQ, ImageGTLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@ -27,7 +27,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -100,8 +101,10 @@ if __name__ == "__main__":
# Setup trainer # Setup trainer
# using GPUs here is strongly recommended! # using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GTLVQ, VisGLVQ2D from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -61,8 +62,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -59,8 +60,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=1, max_epochs=1,
callbacks=[ callbacks=[
vis, vis,

View File

@ -7,10 +7,10 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -104,8 +105,10 @@ if __name__ == "__main__":
vis = Vis2DColorSOM(data=data) vis = Vis2DColorSOM(data=data)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=500, max_epochs=500,
callbacks=[ callbacks=[
vis, vis,

View File

@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import LGMLVQ, VisGLVQ2D from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -62,8 +63,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -6,12 +6,12 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
LVQMLN, LVQMLN,
PruneLoserPrototypes, PruneLoserPrototypes,
VisSiameseGLVQ2D, VisSiameseGLVQ2D,
) )
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -88,8 +89,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -6,9 +6,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import MedianLVQ, VisGLVQ2D from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -20,7 +20,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -53,8 +54,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import NeuralGas, VisNG2D from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
@ -23,7 +23,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare and pre-process the dataset # Prepare and pre-process the dataset
@ -60,8 +61,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_ds) vis = VisNG2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import RSLVQ, VisGLVQ2D from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -54,8 +55,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
) )
# Initialize the backbone # Initialize the backbone
@ -69,8 +69,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
input_dim=2, input_dim=2,
latent_dim=1, latent_dim=1,
) )
@ -71,8 +71,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,6 +6,7 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GLVQ, GLVQ,
KNN, KNN,
@ -14,7 +15,6 @@ from prototorch.models import (
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -27,7 +27,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare the data # Prepare the data
@ -54,7 +55,9 @@ if __name__ == "__main__":
# Setup trainer for GNG # Setup trainer for GNG
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=1000, accelerator="cpu",
max_epochs=50 if args.fast_dev_run else
1000, # 10 epochs fast dev run reproducible DIV error.
callbacks=[ callbacks=[
es, es,
], ],
@ -108,8 +111,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -1,100 +0,0 @@
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.y.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
def main():
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris()
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=32,
num_workers=0,
shuffle=True,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
lr=dict(components_layer=0.1, _omega=0),
input_dim=4,
distribution=dict(
num_classes=3,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GMLVQ(hyperparameters)
print(model.hparams)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=3,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
mode="max",
patience=10,
)
# Visualization Callback
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
], )
# Train
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./y_arch.ckpt")
# Load saved model
new_model = GMLVQ.load_from_checkpoint(
checkpoint_path="./y_arch.ckpt",
strict=True,
)
print(new_model.hparams)
if __name__ == "__main__":
main()

View File

@ -1,35 +0,0 @@
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components
class ProtoTorchMixin(pl.LightningModule):
"""All mixins are ProtoTorchMixins."""
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@ -1,23 +0,0 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import (
OmegaComparisonMixin,
SimpleComparisonMixin,
)
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import (
MultipleLearningRateMixin,
SingleLearningRateMixin,
)
__all__ = [
'BaseYArchitecture',
"OmegaComparisonMixin",
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"MultipleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",
]

View File

@ -1,226 +0,0 @@
"""
Proto Y Architecture
Network architecture for Component based Learning.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
...
# Fields
registered_metrics: dict[type[Metric], Metric] = {}
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
self.save_hyperparameters(
hparams.__dict__,
ignore=["component_initializer"],
)
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competition(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
def init_latent(self, hparams: HyperParameters) -> None:
...
def init_comparison(self, hparams: HyperParameters) -> None:
...
def init_competition(self, hparams: HyperParameters) -> None:
...
def init_loss(self, hparams: HyperParameters) -> None:
...
def init_inference(self, hparams: HyperParameters) -> None:
...
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the component set of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparison_measures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparison_measures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
# Y Architecture Hooks
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch):
# Prediction Metrics
preds = self(batch)
x, y = batch
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
instance(y, preds)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for callback in self.registered_metric_callbacks[metric]:
callback(value, self)
instance.reset()
# Lightning Hooks
# Steps
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step([torch.clone(el) for el in batch])
return self.loss_forward(batch)
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}
return super().on_save_checkpoint(checkpoint)

View File

@ -1,112 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable, Dict
import torch
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.nn.wrappers import LambdaLayer
from prototorch.y.architectures.base import BaseYArchitecture
from torch import Tensor
from torch.nn.parameter import Parameter
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
**self.comparison_kwargs,
)
return distances
class OmegaComparisonMixin(SimpleComparisonMixin):
"""
Omega Comparison
A comparison layer that uses the positions of the components and the batch for dissimilarity computation.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(SimpleComparisonMixin.HyperParameters):
"""
input_dim: Necessary Field: The dimensionality of the input.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters) -> None:
super().init_comparison(hparams)
# Initialize the omega matrix
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach()
lam = omega @ omega.T
return lam.detach().cpu()

View File

@ -1,29 +0,0 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.y.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
"""
Winner Take All Competition
A competition layer that uses the winner-take-all strategy.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
No hyperparameters.
"""
# Steps
# ----------------------------------------------------------------------------------------------------
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def inference(self, comparison_measures, components):
comp_labels = components[1]
return self.competition_layer(comparison_measures, comp_labels)

View File

@ -1,64 +0,0 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.y import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
"""
Supervised Architecture
An architecture that uses labeled Components as component Layer.
"""
components_layer: LabeledComponents
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters:
"""
distribution: A valid prototype distribution. No default possible.
components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
"""
distribution: "dict[str, int]"
component_initializer: AbstractComponentsInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
if hparams.component_initializer is not None:
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
proto_shape = self.components_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def prototypes(self):
"""
Returns the position of the prototypes.
"""
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
"""
Returns the labels of the prototypes.
"""
return self.components_layer.labels.detach().cpu()

View File

@ -1,42 +0,0 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.y.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
"""
GLVQ Loss
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
margin: The margin of the GLVQ loss. Default: 0.0.
transfer_fn: Transfer function to use. Default: sigmoid_beta.
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
"""
margin: float = 0.0
transfer_fn: str = "sigmoid_beta"
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
# Steps
# ----------------------------------------------------------------------------------------------------
def init_loss(self, hparams: HyperParameters):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
**hparams.transfer_args,
)
def loss(self, comparison_measures, batch, components):
target = batch[1]
comp_labels = components[1]
loss = self.loss_layer(comparison_measures, target, comp_labels)
self.log('loss', loss)
return loss

View File

@ -1,73 +0,0 @@
from dataclasses import dataclass, field
from typing import Type
import torch
from prototorch.y import BaseYArchitecture
from torch.nn.parameter import Parameter
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.hparams.optimizer(self.parameters(),
lr=self.hparams.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
"""
Multiple Learning Rates
Define Different Learning Rates for different parameters.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.hparams.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.hparams.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.hparams.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
return optimizers

View File

@ -1,218 +0,0 @@
import warnings
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
from prototorch.y.architectures.base import BaseYArchitecture
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [
'PiYG',
'PRGn',
'BrBG',
'PuOr',
'RdGy',
'RdBu',
'RdYlBu',
'RdYlGn',
'Spectral',
'coolwarm',
'bwr',
'seismic',
]
class LogTorchmetricCallback(pl.Callback):
def __init__(
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
def setup(
self,
trainer: pl.Trainer,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is no valid metric hook")
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
class LogConfusionMatrix(LogTorchmetricCallback):
def __init__(
self,
num_classes,
name="confusion",
on='prediction',
**kwargs,
):
super().__init__(
name,
torchmetrics.ConfusionMatrix,
on=on,
num_classes=num_classes,
**kwargs,
)
def __call__(self, value, pl_module: BaseYArchitecture):
fig, ax = plt.subplots()
ax.imshow(value.detach().cpu().numpy())
# Show all ticks and label them with the respective list entries
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(
ax.get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
# Loop over data dimensions and create text annotations.
for i in range(len(value)):
for j in range(len(value)):
text = ax.text(
j,
i,
value[i, j].item(),
ha="center",
va="center",
color="w",
)
ax.set_title(self.name)
fig.tight_layout()
pl_module.logger.experiment.add_figure(
tag=self.name,
figure=fig,
close=True,
global_step=pl_module.global_step,
)
class VisGLVQ2D(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(
np.vstack([x_train, protos]),
self.border,
self.resolution,
)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.components_layer.components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class PlotLambdaMatrixToTensorboard(pl.Callback):
def __init__(self, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
warnings.warn(
f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
)
def on_train_start(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def plot_lambda(self, trainer, pl_module: GMLVQ):
self.fig, self.ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = pl_module.lambda_matrix
# normalize lambda matrix
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
# plot lambda matrix
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
self.fig.colorbar(self.ax.images[-1])
# add title
self.ax.set_title('Lambda Matrix')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"lambda_matrix",
self.fig,
trainer.global_step,
)
else:
warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
)

View File

@ -1,7 +0,0 @@
from .glvq import GLVQ
from .gmlvq import GMLVQ
__all__ = [
"GLVQ",
"GMLVQ",
]

View File

@ -1,35 +0,0 @@
from dataclasses import dataclass
from prototorch.y import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.y.architectures.loss import GLVQLossMixin
class GLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
SingleLearningRateMixin,
):
"""
Generalized Learning Vector Quantization (GLVQ)
A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
SingleLearningRateMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
No hyperparameters.
"""

View File

@ -1,50 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.y import (
GLVQLossMixin,
MultipleLearningRateMixin,
OmegaComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
class GMLVQ(
SupervisedArchitecture,
OmegaComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
MultipleLearningRateMixin,
):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
MultipleLearningRateMixin.HyperParameters,
OmegaComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(
components_layer=0.1,
_omega=0.5,
))

90
pyproject.toml Normal file
View File

@ -0,0 +1,90 @@
[project]
name = "prototorch-models"
version = "0.7.1"
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
authors = [
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
]
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
requires-python = ">=3.8"
readme = "README.md"
license = { text = "MIT" }
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Homepage = "https://github.com/si-cim/prototorch_models"
Downloads = "https://github.com/si-cim/prototorch_models.git"
[project.optional-dependencies]
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
examples = ["matplotlib", "scikit-learn"]
ci = ["pytest", "pre-commit"]
docs = [
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
all = [
"bumpversion",
"pre-commit",
"yapf",
"toml",
"pytest",
"matplotlib",
"scikit-learn",
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 2
split_before_logical_operator = true
[tool.pylint]
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
[tool.isort]
profile = "hug"
src_paths = ["isort", "test"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 3
use_parentheses = true
line_length = 79
[tool.mypy]
explicit_package_bases = true
namespace_packages = true

View File

@ -1,23 +0,0 @@
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79

View File

@ -1,99 +0,0 @@
"""
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics",
"protobuf<3.20.0",
]
CLI = [
"jsonargparse",
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
]
EXAMPLES = [
"matplotlib",
"scikit-learn",
]
TESTS = [
"codecov",
"pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="1.0.0-a4",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
"examples": EXAMPLES,
"tests": TESTS,
"all": ALL,
},
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
},
packages=find_namespace_packages(include=["prototorch.*"]),
zip_safe=False,
)

View File

@ -36,4 +36,4 @@ from .unsupervised import (
) )
from .vis import * from .vis import *
__version__ = "1.0.0-a4" __version__ = "0.7.1"

View File

@ -22,16 +22,7 @@ from prototorch.nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule): class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts. """All ProtoTorch models are ProtoTorch Bolts."""
hparams:
- lr: learning rate
kwargs:
- optimizer: optimizer class
- lr_scheduler: learning rate scheduler class
- lr_scheduler_kwargs: learning rate scheduler kwargs
"""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()
@ -74,18 +65,13 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt): class PrototypeModel(ProtoTorchBolt):
"""Abstract Prototype Model
kwargs:
- distance_fn: distance function
"""
proto_layer: AbstractComponents proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance) distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn) self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
@property @property
def num_prototypes(self): def num_prototypes(self):
@ -200,20 +186,63 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag): def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances) preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
# `.int()` because FloatTensors are assumed to be class probabilities preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(tag, self.log(
tag,
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
x, targets = batch x, targets = batch
preds = self.predict(x) preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy) self.log("test_acc", accuracy)
class ProtoTorchMixin:
"""All mixins are ProtoTorchMixins."""
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback):
return None return None
ratios = pl_module.prototype_win_ratios.mean(dim=0) ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold] to_prune = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune_tensor.tolist() to_prune = to_prune.tolist()
prune_labels = pl_module.prototype_labels[to_prune] prune_labels = pl_module.prototype_labels[to_prune]
if self.prune_quota_per_epoch > 0: if self.prune_quota_per_epoch > 0:
to_prune = to_prune[:self.prune_quota_per_epoch] to_prune = to_prune[:self.prune_quota_per_epoch]

View File

@ -1,5 +1,4 @@
import torch import torch
import torch.nn.functional as F
import torchmetrics import torchmetrics
from prototorch.core.competitions import CBCC from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents from prototorch.core.components import ReasoningComponents
@ -8,13 +7,12 @@ from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer from prototorch.nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ from .glvq import SiameseGLVQ
from .mixins import ImagePrototypesMixin
class CBC(SiameseGLVQ): class CBC(SiameseGLVQ):
"""Classification-By-Components.""" """Classification-By-Components."""
proto_layer: ReasoningComponents
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs) super().__init__(hparams, skip_proto_layer=True, **kwargs)
@ -24,7 +22,7 @@ class CBC(SiameseGLVQ):
reasonings_initializer = kwargs.get("reasonings_initializer", reasonings_initializer = kwargs.get("reasonings_initializer",
RandomReasoningsInitializer()) RandomReasoningsInitializer())
self.components_layer = ReasoningComponents( self.components_layer = ReasoningComponents(
self.hparams["distribution"], self.hparams.distribution,
components_initializer=components_initializer, components_initializer=components_initializer,
reasonings_initializer=reasonings_initializer, reasonings_initializer=reasonings_initializer,
) )
@ -34,7 +32,7 @@ class CBC(SiameseGLVQ):
# Namespace hook # Namespace hook
self.proto_layer = self.components_layer self.proto_layer = self.components_layer
self.loss = MarginLoss(self.hparams["margin"]) self.loss = MarginLoss(self.hparams.margin)
def forward(self, x): def forward(self, x):
components, reasonings = self.components_layer() components, reasonings = self.components_layer()
@ -46,25 +44,31 @@ class CBC(SiameseGLVQ):
probs = self.competition_layer(detections, reasonings) probs = self.competition_layer(detections, reasonings)
return probs return probs
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
y_pred = self(x) y_pred = self(x)
num_classes = self.num_classes num_classes = self.num_classes
y_true = F.one_hot(y.long(), num_classes=num_classes) y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean() loss = self.loss(y_pred, y_true).mean()
return y_pred, loss return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) y_pred, train_loss = self.shared_step(batch, batch_idx)
preds = torch.argmax(y_pred, dim=1) preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(), accuracy = torchmetrics.functional.accuracy(
batch[1].int()) preds.int(),
self.log("train_acc", batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
return train_loss return train_loss
def predict(self, x): def predict(self, x):

View File

@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
:param `torch.tensor` omegas: Three dimensional matrix :param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor` :rtype: `torch.tensor`
""" """
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1])) omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p projected_x = x @ p

View File

@ -17,9 +17,8 @@ from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .abstract import SupervisedPrototypeModel from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization from .extras import ltangent_distance, orthogonalization
from .mixins import ImagePrototypesMixin
class GLVQ(SupervisedPrototypeModel): class GLVQ(SupervisedPrototypeModel):
@ -47,24 +46,19 @@ class GLVQ(SupervisedPrototypeModel):
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(
"prototype_win_ratios", "prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device), torch.zeros(self.num_prototypes, device=self.device))
)
def on_train_epoch_start(self): def on_train_epoch_start(self):
self.initialize_prototype_win_ratios() self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances): def log_prototype_win_ratios(self, distances):
batch_size = len(distances) batch_size = len(distances)
prototype_wc = torch.zeros( prototype_wc = torch.zeros(self.num_prototypes,
self.num_prototypes,
dtype=torch.long, dtype=torch.long,
device=self.device, device=self.device)
) wi, wc = torch.unique(distances.min(dim=-1).indices,
wi, wc = torch.unique(
distances.min(dim=-1).indices,
sorted=True, sorted=True,
return_counts=True, return_counts=True)
)
prototype_wc[wi] = wc prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([ self.prototype_win_ratios = torch.vstack([
@ -72,27 +66,29 @@ class GLVQ(SupervisedPrototypeModel):
prototype_wr, prototype_wr,
]) ])
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.compute_distances(x) out = self.compute_distances(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
loss = self.loss(out, y, plabels) loss = self.loss(out, y, plabels)
return out, loss return out, loss
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) out, train_loss = self.shared_step(batch, batch_idx)
self.log_prototype_win_ratios(out) self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss) self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc") self.log_acc(out, batch[-1], tag="train_acc")
return train_loss return train_loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, val_loss = self.shared_step(batch, batch_idx) out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss) self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc") self.log_acc(out, batch[-1], tag="val_acc")
return val_loss return val_loss
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, test_loss = self.shared_step(batch, batch_idx) out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc") self.log_acc(out, batch[-1], tag="test_acc")
return test_loss return test_loss
@ -113,43 +109,19 @@ class SiameseGLVQ(GLVQ):
""" """
def __init__( def __init__(self,
self,
hparams, hparams,
backbone=torch.nn.Identity(), backbone=torch.nn.Identity(),
both_path_gradients=False, both_path_gradients=False,
**kwargs, **kwargs):
):
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance) distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
self.backbone = backbone self.backbone = backbone
self.both_path_gradients = both_path_gradients self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(
self.proto_layer.parameters(),
lr=self.hparams["proto_lr"],
)
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
latent_x = self.backbone(x) latent_x = self.backbone(x)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()]) bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
@ -213,9 +185,12 @@ class GRLVQ(SiameseGLVQ):
self.register_parameter("_relevances", Parameter(relevances)) self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone # Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances), self.backbone = LambdaLayer(self._apply_relevances,
name="relevance scaling") name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property @property
def relevance_profile(self): def relevance_profile(self):
return self._relevances.detach().cpu() return self._relevances.detach().cpu()
@ -270,19 +245,11 @@ class GMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters # Additional parameters
omega_initializer = kwargs.get( omega_initializer = kwargs.get("omega_initializer",
"omega_initializer", EyeLinearTransformInitializer())
EyeLinearTransformInitializer(), omega = omega_initializer.generate(self.hparams["input_dim"],
) self.hparams["latent_dim"])
omega = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(
lambda x: x @ self._omega,
name="omega matrix",
)
@property @property
def omega_matrix(self): def omega_matrix(self):

View File

@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
labels_initializer=LiteralLabelsInitializer(targets)) labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k) self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
return 1 # skip training step return 1 # skip training step
def on_train_batch_start(self, train_batch, batch_idx): def on_train_batch_start(self, train_batch, batch_idx):

View File

@ -1,21 +1,20 @@
"""LVQ models that are optimized using non-gradient methods.""" """LVQ models that are optimized using non-gradient methods."""
import logging import logging
from collections import OrderedDict
from prototorch.core.losses import _get_dp_dm from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer from prototorch.nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ from .glvq import GLVQ
from .mixins import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ): class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1.""" """Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer() protos, plables = self.proto_layer()
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)
# TODO Vectorized implementation # TODO Vectorized implementation
@ -29,11 +28,9 @@ class LVQ1(NonGradientMixin, GLVQ):
else: else:
shift = protos[w] - xi shift = protos[w] - xi
updated_protos = protos + 0.0 updated_protos = protos + 0.0
updated_protos[w] = protos[w] + (self.hparams["lr"] * shift) updated_protos[w] = protos[w] + (self.hparams.lr * shift)
self.proto_layer.load_state_dict( self.proto_layer.load_state_dict({"_components": updated_protos},
OrderedDict(_components=updated_protos), strict=False)
strict=False,
)
logging.debug(f"dis={dis}") logging.debug(f"dis={dis}")
logging.debug(f"y={y}") logging.debug(f"y={y}")
@ -46,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ): class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1.""" """Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
x, y = train_batch x, y = train_batch
@ -61,12 +58,10 @@ class LVQ21(NonGradientMixin, GLVQ):
shiftp = xi - protos[wp] shiftp = xi - protos[wp]
shiftn = protos[wn] - xi shiftn = protos[wn] - xi
updated_protos = protos + 0.0 updated_protos = protos + 0.0
updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp) updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn) updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
self.proto_layer.load_state_dict( self.proto_layer.load_state_dict({"_components": updated_protos},
OrderedDict(_components=updated_protos), strict=False)
strict=False,
)
# Logging # Logging
self.log_acc(dis, y, tag="train_acc") self.log_acc(dis, y, tag="train_acc")
@ -85,17 +80,14 @@ class MedianLVQ(NonGradientMixin, GLVQ):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer( self.transfer_layer = LambdaLayer(
get_activation(self.hparams["transfer_fn"])) get_activation(self.hparams.transfer_fn))
def _f(self, x, y, protos, plabels): def _f(self, x, y, protos, plabels):
d = self.distance_layer(x, protos) d = self.distance_layer(x, protos)
dp, dm = _get_dp_dm(d, y, plabels, with_indices=False) dp, dm = _get_dp_dm(d, y, plabels)
mu = (dp - dm) / (dp + dm) mu = (dp - dm) / (dp + dm)
negative_mu = -1.0 * mu invmu = -1.0 * mu
f = self.transfer_layer( f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
negative_mu,
beta=self.hparams["transfer_beta"],
) + 1.0
return f return f
def expectation(self, x, y, protos, plabels): def expectation(self, x, y, protos, plabels):
@ -108,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
lower_bound = (gamma * f.log()).sum() lower_bound = (gamma * f.log()).sum()
return lower_bound return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
x, y = train_batch x, y = train_batch
@ -126,10 +118,8 @@ class MedianLVQ(NonGradientMixin, GLVQ):
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma) _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound: if _lower_bound > lower_bound:
logging.debug(f"Updating prototype {i} to data {k}...") logging.debug(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict( self.proto_layer.load_state_dict({"_components": _protos},
OrderedDict(_components=_protos), strict=False)
strict=False,
)
break break
# Logging # Logging

View File

@ -21,7 +21,7 @@ class CELVQ(GLVQ):
# Loss # Loss
self.loss = torch.nn.CrossEntropyLoss() self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.compute_distances(x) # [None, num_protos] out = self.compute_distances(x) # [None, num_protos]
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
prediction[confidence < self.rejection_confidence] = -1 prediction[confidence < self.rejection_confidence] = -1
return prediction return prediction
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.forward(x) out = self.forward(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
self.loss = torch.nn.KLDivLoss() self.loss = torch.nn.KLDivLoss()
# FIXME # FIXME
# def training_step(self, batch, batch_idx, optimizer_idx=None): # def training_step(self, batch, batch_idx):
# x, y = batch # x, y = batch
# y_pred = self(x) # y_pred = self(x)
# batch_loss = self.loss(y_pred, y) # batch_loss = self.loss(y_pred, y)

View File

@ -6,10 +6,9 @@ from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy from prototorch.core.losses import NeuralGasEnergy
from .abstract import UnsupervisedPrototypeModel from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback from .callbacks import GNGCallback
from .extras import ConnectionTopology from .extras import ConnectionTopology
from .mixins import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
@ -64,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
strict=False, strict=False,
) )
def training_epoch_end(self, training_step_outputs): def on_training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp( self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs) -self.current_epoch / self.trainer.max_epochs)

View File

@ -1,6 +1,5 @@
"""Visualization Callbacks.""" """Visualization Callbacks."""
import os
import warnings import warnings
from typing import Sized from typing import Sized
@ -33,10 +32,6 @@ class Vis2DAbstract(pl.Callback):
tensorboard=False, tensorboard=False,
show_last_only=False, show_last_only=False,
pause_time=0.1, pause_time=0.1,
save=False,
save_dir="./img",
fig_size=(5, 4),
dpi=500,
block=False): block=False):
super().__init__() super().__init__()
@ -80,16 +75,8 @@ class Vis2DAbstract(pl.Callback):
self.tensorboard = tensorboard self.tensorboard = tensorboard
self.show_last_only = show_last_only self.show_last_only = show_last_only
self.pause_time = pause_time self.pause_time = pause_time
self.save = save
self.save_dir = save_dir
self.fig_size = fig_size
self.dpi = dpi
self.block = block self.block = block
if save:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
def precheck(self, trainer): def precheck(self, trainer):
if self.show_last_only: if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1: if trainer.current_epoch != trainer.max_epochs - 1:
@ -138,11 +125,6 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module): def log_and_display(self, trainer, pl_module):
if self.tensorboard: if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module) self.add_to_tensorboard(trainer, pl_module)
if self.save:
plt.tight_layout()
self.fig.set_size_inches(*self.fig_size, forward=False)
plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png",
dpi=self.dpi)
if self.show: if self.show:
if not self.block: if not self.block:
plt.pause(self.pause_time) plt.pause(self.pause_time)

View File

@ -1,195 +1,193 @@
"""prototorch.models test suite.""" """prototorch.models test suite."""
import prototorch as pt import prototorch.models
import pytest
import torch
def test_glvq_model_build(): def test_glvq_model_build():
model = pt.models.GLVQ( model = prototorch.models.GLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_glvq1_model_build(): def test_glvq1_model_build():
model = pt.models.GLVQ1( model = prototorch.models.GLVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_glvq21_model_build(): def test_glvq21_model_build():
model = pt.models.GLVQ1( model = prototorch.models.GLVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_gmlvq_model_build(): def test_gmlvq_model_build():
model = pt.models.GMLVQ( model = prototorch.models.GMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 2, "input_dim": 2,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_grlvq_model_build(): def test_grlvq_model_build():
model = pt.models.GRLVQ( model = prototorch.models.GRLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 2, "input_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_gtlvq_model_build(): def test_gtlvq_model_build():
model = pt.models.GTLVQ( model = prototorch.models.GTLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_lgmlvq_model_build(): def test_lgmlvq_model_build():
model = pt.models.LGMLVQ( model = prototorch.models.LGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_image_glvq_model_build(): def test_image_glvq_model_build():
model = pt.models.ImageGLVQ( model = prototorch.models.ImageGLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_image_gmlvq_model_build(): def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ( model = prototorch.models.ImageGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 16, "input_dim": 16,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_image_gtlvq_model_build(): def test_image_gtlvq_model_build():
model = pt.models.ImageGMLVQ( model = prototorch.models.ImageGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 16, "input_dim": 16,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_siamese_glvq_model_build(): def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ( model = prototorch.models.SiameseGLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_siamese_gmlvq_model_build(): def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ( model = prototorch.models.SiameseGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_siamese_gtlvq_model_build(): def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ( model = prototorch.models.SiameseGTLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_knn_model_build(): def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = prototorch.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds) model = prototorch.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build(): def test_lvq1_model_build():
model = pt.models.LVQ1( model = prototorch.models.LVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_lvq21_model_build(): def test_lvq21_model_build():
model = pt.models.LVQ21( model = prototorch.models.LVQ21(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_median_lvq_model_build(): def test_median_lvq_model_build():
model = pt.models.MedianLVQ( model = prototorch.models.MedianLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_celvq_model_build(): def test_celvq_model_build():
model = pt.models.CELVQ( model = prototorch.models.CELVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_rslvq_model_build(): def test_rslvq_model_build():
model = pt.models.RSLVQ( model = prototorch.models.RSLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_slvq_model_build(): def test_slvq_model_build():
model = pt.models.SLVQ( model = prototorch.models.SLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_growing_neural_gas_model_build(): def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas( model = prototorch.models.GrowingNeuralGas(
{"num_prototypes": 5}, {"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_kohonen_som_model_build(): def test_kohonen_som_model_build():
model = pt.models.KohonenSOM( model = prototorch.models.KohonenSOM(
{"shape": (3, 2)}, {"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_neural_gas_model_build(): def test_neural_gas_model_build():
model = pt.models.NeuralGas( model = prototorch.models.NeuralGas(
{"num_prototypes": 5}, {"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )