32 Commits
dev ... GMLMLVQ

Author SHA1 Message Date
julius
adafb49985 masks -> ParameterList(requires_grad=False) 2023-11-07 19:17:43 +01:00
julius
78f8b6cc00 remove accidental LiteralString import 2023-11-07 18:52:51 +01:00
julius
c6f718a1d4 GMLMLVQ: allow for 2 or more omega layers 2023-11-07 16:44:13 +01:00
julius
1786031b4e adjust omega_matrix property 2023-11-06 16:32:57 +01:00
julius
824dfced92 Implement a prototypical 2-layer version of GMLVQ 2023-11-03 14:59:00 +01:00
Alexander Engelsberger
d4bf6dbbe9 build: bump version 0.7.0 → 0.7.1 2023-10-25 15:56:53 +02:00
Alexander Engelsberger
c99fdb436c ci: update action to pyproject toml workflow 2023-10-25 15:56:19 +02:00
Alexander Engelsberger
28ac5f5ed9 build: bump version 0.6.0 → 0.7.0 2023-10-25 15:19:04 +02:00
Alexander Engelsberger
b7f510a9fe chore: update bumpversion config 2023-10-25 15:18:45 +02:00
Alexander Engelsberger
781ef93b06 ci: remove Python 3.12 2023-10-25 15:09:14 +02:00
Alexander Engelsberger
072e61b3cd ci: Add Python 3.12 2023-10-25 15:04:05 +02:00
Alexander Engelsberger
71167a8f77 chore: remove optimizer_idx from all steps 2023-10-25 15:03:13 +02:00
Alexander Engelsberger
60990f42d2 fix: update import in tests 2023-06-20 21:18:28 +02:00
Alexander Engelsberger
1e83c439f7 ci: Trigger example test 2023-06-20 19:29:59 +02:00
Alexander Engelsberger
cbbbbeda98 fix: setuptools configuration 2023-06-20 19:25:35 +02:00
Alexander Engelsberger
1b5093627e build: bump version 0.5.4 → 0.6.0 2023-06-20 18:50:03 +02:00
Alexander Engelsberger
497da90f9c chore: small changes to configuration 2023-06-20 18:49:57 +02:00
Alexander Engelsberger
2a665e220f fix: use multiclass accuracy by default 2023-06-20 18:30:18 +02:00
Alexander Engelsberger
4cd6aee330 chore: replace config by pyproject.toml 2023-06-20 18:30:05 +02:00
Alexander Engelsberger
634ef86a2c fix: example test fixed 2023-06-20 17:42:36 +02:00
Alexander Engelsberger
72e9587a10 fix: remove removed CLI syntax from examples 2023-06-20 17:30:21 +02:00
Alexander Engelsberger
f5e1edf31f ci: upgrade workflows 2023-06-20 16:39:13 +02:00
Alexander Engelsberger
5e5675d12e ci: upgrade pre-commit config 2023-06-20 16:37:11 +02:00
Alexander Engelsberger
16f410e809 fix: style fixes 2023-03-09 15:59:49 +01:00
Alexander Engelsberger
46dfb82371 Fix: saving GMLVQ and GRLVQ fixed 2023-03-09 15:50:13 +01:00
Alexander Engelsberger
87fa3f0729 build: bump version 0.5.3 → 0.5.4 2023-03-02 17:29:54 +00:00
Alexander Engelsberger
08db94d507 fix: fix entrypoint configuration 2023-03-02 17:29:23 +00:00
Alexander Engelsberger
8ecf9948b2 build: bump version 0.5.2 → 0.5.3 2023-03-02 17:24:11 +00:00
Alexander Engelsberger
c5f0b86114 chore: upgrade pre commit 2023-03-02 17:23:41 +00:00
Alexander Engelsberger
7506614ada fix: Update dependency versions 2023-03-02 17:05:39 +00:00
Alexander Engelsberger
fcd944d3ff build: bump version 0.5.1 → 0.5.2 2022-06-01 14:25:44 +02:00
Alexander Engelsberger
054720dd7b fix(hotfix): Protobuf error workaround 2022-06-01 14:14:57 +02:00
40 changed files with 592 additions and 447 deletions

View File

@@ -1,13 +1,13 @@
[bumpversion] [bumpversion]
current_version = 0.5.1 current_version = 0.7.1
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch} serialize = {major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version} message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py] [bumpversion:file:pyproject.toml]
[bumpversion:file:./prototorch/models/__init__.py] [bumpversion:file:./src/prototorch/models/__init__.py]
[bumpversion:file:./docs/source/conf.py] [bumpversion:file:./docs/source/conf.py]

View File

@@ -6,16 +6,16 @@ name: examples
on: on:
push: push:
paths: paths:
- 'examples/**.py' - "examples/**.py"
jobs: jobs:
cpu: cpu:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View File

@@ -6,42 +6,42 @@ name: tests
on: on:
push: push:
pull_request: pull_request:
branches: [ master ] branches: [master]
jobs: jobs:
style: style:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
- uses: pre-commit/action@v2.0.3 - uses: pre-commit/action@v3.0.0
compatibility: compatibility:
needs: style needs: style
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest] os: [ubuntu-latest, windows-latest]
exclude: exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest - os: windows-latest
python-version: "3.8" python-version: "3.8"
- os: windows-latest - os: windows-latest
python-version: "3.9" python-version: "3.9"
- os: windows-latest
python-version: "3.10"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
@@ -56,18 +56,18 @@ jobs:
needs: compatibility needs: compatibility
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.11"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
pip install wheel pip install build
- name: Build package - name: Build package
run: python setup.py sdist bdist_wheel run: python -m build . -C verbose
- name: Publish a Python distribution to PyPI - name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:

View File

@@ -2,8 +2,8 @@
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0 rev: v4.4.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: end-of-file-fixer - id: end-of-file-fixer
@@ -12,41 +12,42 @@ repos:
- id: check-ast - id: check-ast
- id: check-case-conflict - id: check-case-conflict
- repo: https://github.com/myint/autoflake - repo: https://github.com/myint/autoflake
rev: v1.4 rev: v2.1.1
hooks: hooks:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.10.1 rev: 5.12.0
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.950 rev: v1.3.0
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
additional_dependencies: [types-pkg_resources] additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0 rev: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
additional_dependencies: ["toml"]
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0 rev: v1.10.0
hooks: hooks:
- id: python-use-type-annotations - id: python-use-type-annotations
- id: python-no-log-warn - id: python-no-log-warn
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.32.1 rev: v3.7.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
- repo: https://github.com/si-cim/gitlint - repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial rev: v0.15.2-unofficial
hooks: hooks:
- id: gitlint - id: gitlint

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.5.1" release = "0.7.1"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -1,12 +1,11 @@
"""CBC example using the Iris dataset.""" """CBC example using the Iris dataset."""
import argparse import argparse
import warnings import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import CBC, VisCBC2D from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -19,7 +18,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -53,8 +53,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -7,13 +7,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
CELVQ, CELVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -26,7 +26,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -83,8 +84,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GLVQ, VisGLVQ2D from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -21,7 +21,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -55,8 +56,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GMLVQ, VisGMLVQ2D from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -22,7 +22,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -59,8 +60,10 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds) vis = VisGMLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],
@@ -71,3 +74,5 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGMLVQ, ImageGMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@@ -26,7 +26,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -96,8 +97,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GMLVQ, GMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -25,7 +25,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -78,8 +79,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GrowingNeuralGas, VisNG2D from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@@ -51,8 +52,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_loader) vis = VisNG2D(data=train_loader)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

77
examples/grlvq_iris.py Normal file
View File

@@ -0,0 +1,77 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris([0, 1])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=2,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GRLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=5,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGTLVQ, ImageGTLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@@ -27,7 +27,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -100,8 +101,10 @@ if __name__ == "__main__":
# Setup trainer # Setup trainer
# using GPUs here is strongly recommended! # using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GTLVQ, VisGLVQ2D from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@@ -61,8 +62,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -59,8 +60,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=1, max_epochs=1,
callbacks=[ callbacks=[
vis, vis,

View File

@@ -7,10 +7,10 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
@@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@@ -104,8 +105,10 @@ if __name__ == "__main__":
vis = Vis2DColorSOM(data=data) vis = Vis2DColorSOM(data=data)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=500, max_epochs=500,
callbacks=[ callbacks=[
vis, vis,

View File

@@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import LGMLVQ, VisGLVQ2D from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@@ -62,8 +63,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@@ -6,12 +6,12 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
LVQMLN, LVQMLN,
PruneLoserPrototypes, PruneLoserPrototypes,
VisSiameseGLVQ2D, VisSiameseGLVQ2D,
) )
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -88,8 +89,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@@ -6,9 +6,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import MedianLVQ, VisGLVQ2D from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -20,7 +20,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -53,8 +54,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import NeuralGas, VisNG2D from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
@@ -23,7 +23,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare and pre-process the dataset # Prepare and pre-process the dataset
@@ -60,8 +61,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_ds) vis = VisNG2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import RSLVQ, VisGLVQ2D from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@@ -54,8 +55,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
) )
# Initialize the backbone # Initialize the backbone
@@ -69,8 +69,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
input_dim=2, input_dim=2,
latent_dim=1, latent_dim=1,
) )
@@ -71,8 +71,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@@ -6,6 +6,7 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GLVQ, GLVQ,
KNN, KNN,
@@ -14,7 +15,6 @@ from prototorch.models import (
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -27,7 +27,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare the data # Prepare the data
@@ -54,7 +55,9 @@ if __name__ == "__main__":
# Setup trainer for GNG # Setup trainer for GNG
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=1000, accelerator="cpu",
max_epochs=50 if args.fast_dev_run else
1000, # 10 epochs fast dev run reproducible DIV error.
callbacks=[ callbacks=[
es, es,
], ],
@@ -108,8 +111,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

90
pyproject.toml Normal file
View File

@@ -0,0 +1,90 @@
[project]
name = "prototorch-models"
version = "0.7.1"
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
authors = [
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
]
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
requires-python = ">=3.8"
readme = "README.md"
license = { text = "MIT" }
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Homepage = "https://github.com/si-cim/prototorch_models"
Downloads = "https://github.com/si-cim/prototorch_models.git"
[project.optional-dependencies]
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
examples = ["matplotlib", "scikit-learn"]
ci = ["pytest", "pre-commit"]
docs = [
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
all = [
"bumpversion",
"pre-commit",
"yapf",
"toml",
"pytest",
"matplotlib",
"scikit-learn",
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 2
split_before_logical_operator = true
[tool.pylint]
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
[tool.isort]
profile = "hug"
src_paths = ["isort", "test"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 3
use_parentheses = true
line_length = 79
[tool.mypy]
explicit_package_bases = true
namespace_packages = true

View File

@@ -1,23 +0,0 @@
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79

View File

@@ -1,98 +0,0 @@
"""
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics",
]
CLI = [
"jsonargparse",
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
]
EXAMPLES = [
"matplotlib",
"scikit-learn",
]
TESTS = [
"codecov",
"pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.5.1",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
"examples": EXAMPLES,
"tests": TESTS,
"all": ALL,
},
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
},
packages=find_namespace_packages(include=["prototorch.*"]),
zip_safe=False,
)

View File

@@ -36,4 +36,4 @@ from .unsupervised import (
) )
from .vis import * from .vis import *
__version__ = "0.5.1" __version__ = "0.7.1"

View File

@@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance) distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn) self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
@property @property
def num_prototypes(self): def num_prototypes(self):
@@ -186,26 +186,37 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag): def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances) preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
# `.int()` because FloatTensors are assumed to be class probabilities preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(tag, self.log(
tag,
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
x, targets = batch x, targets = batch
preds = self.predict(x) preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy) self.log("test_acc", accuracy)
class ProtoTorchMixin(object): class ProtoTorchMixin:
"""All mixins are ProtoTorchMixins.""" """All mixins are ProtoTorchMixins."""
@@ -216,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.automatic_optimization = False self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
raise NotImplementedError raise NotImplementedError

View File

@@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
probs = self.competition_layer(detections, reasonings) probs = self.competition_layer(detections, reasonings)
return probs return probs
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
y_pred = self(x) y_pred = self(x)
num_classes = self.num_classes num_classes = self.num_classes
@@ -52,17 +52,23 @@ class CBC(SiameseGLVQ):
loss = self.loss(y_pred, y_true).mean() loss = self.loss(y_pred, y_true).mean()
return y_pred, loss return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) y_pred, train_loss = self.shared_step(batch, batch_idx)
preds = torch.argmax(y_pred, dim=1) preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(), accuracy = torchmetrics.functional.accuracy(
batch[1].int()) preds.int(),
self.log("train_acc", batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
return train_loss return train_loss
def predict(self, x): def predict(self, x):

View File

@@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
:param `torch.tensor` omegas: Three dimensional matrix :param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor` :rtype: `torch.tensor`
""" """
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1])) omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p projected_x = x @ p

View File

@@ -1,13 +1,15 @@
"""Models based on the GLVQ framework.""" """Models based on the GLVQ framework."""
import torch import torch
from numpy.typing import NDArray
from prototorch.core.competitions import wtac from prototorch.core.competitions import wtac
from prototorch.core.distances import ( from prototorch.core.distances import (
ML_omega_distance,
lomega_distance, lomega_distance,
omega_distance, omega_distance,
squared_euclidean_distance, squared_euclidean_distance,
) )
from prototorch.core.initializers import EyeLinearTransformInitializer from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
from prototorch.core.losses import ( from prototorch.core.losses import (
GLVQLoss, GLVQLoss,
lvq1_loss, lvq1_loss,
@@ -15,7 +17,7 @@ from prototorch.core.losses import (
) )
from prototorch.core.transforms import LinearTransform from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter from torch.nn import Parameter, ParameterList
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization from .extras import ltangent_distance, orthogonalization
@@ -45,36 +47,38 @@ class GLVQ(SupervisedPrototypeModel):
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(
"prototype_win_ratios", "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
torch.zeros(self.num_prototypes, device=self.device)) )
def on_train_epoch_start(self): def on_train_epoch_start(self):
self.initialize_prototype_win_ratios() self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances): def log_prototype_win_ratios(self, distances):
batch_size = len(distances) batch_size = len(distances)
prototype_wc = torch.zeros(self.num_prototypes, prototype_wc = torch.zeros(
dtype=torch.long, self.num_prototypes, dtype=torch.long, device=self.device
device=self.device) )
wi, wc = torch.unique(distances.min(dim=-1).indices, wi, wc = torch.unique(
sorted=True, distances.min(dim=-1).indices, sorted=True, return_counts=True
return_counts=True) )
prototype_wc[wi] = wc prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([ self.prototype_win_ratios = torch.vstack(
[
self.prototype_win_ratios, self.prototype_win_ratios,
prototype_wr, prototype_wr,
]) ]
)
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.compute_distances(x) out = self.compute_distances(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
loss = self.loss(out, y, plabels) loss = self.loss(out, y, plabels)
return out, loss return out, loss
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) out, train_loss = self.shared_step(batch, batch_idx)
self.log_prototype_win_ratios(out) self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss) self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc") self.log_acc(out, batch[-1], tag="train_acc")
@@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
test_loss += batch_loss.item() test_loss += batch_loss.item()
self.log("test_loss", test_loss) self.log("test_loss", test_loss)
# TODO
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
class SiameseGLVQ(GLVQ): class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting. """GLVQ in a Siamese setting.
@@ -113,39 +113,17 @@ class SiameseGLVQ(GLVQ):
""" """
def __init__(self, def __init__(
hparams, self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
backbone=torch.nn.Identity(), ):
both_path_gradients=False,
**kwargs):
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance) distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
self.backbone = backbone self.backbone = backbone
self.both_path_gradients = both_path_gradients self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams["proto_lr"])
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
latent_x = self.backbone(x) latent_x = self.backbone(x)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()]) bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
@@ -199,6 +177,7 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
""" """
_relevances: torch.Tensor _relevances: torch.Tensor
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
@@ -209,8 +188,10 @@ class GRLVQ(SiameseGLVQ):
self.register_parameter("_relevances", Parameter(relevances)) self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone # Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances), self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property @property
def relevance_profile(self): def relevance_profile(self):
@@ -231,8 +212,9 @@ class SiameseGMLVQ(SiameseGLVQ):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Override the backbone # Override the backbone
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get(
EyeLinearTransformInitializer()) "omega_initializer", EyeLinearTransformInitializer()
)
self.backbone = LinearTransform( self.backbone = LinearTransform(
self.hparams["input_dim"], self.hparams["input_dim"],
self.hparams["latent_dim"], self.hparams["latent_dim"],
@@ -250,6 +232,49 @@ class SiameseGMLVQ(SiameseGLVQ):
return lam.detach().cpu() return lam.detach().cpu()
class GMLMLVQ(GLVQ):
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
Masks are applied to the omega layers to achieve sparsity and constrain
learning to certain items of each omega.
Implemented as a regular GLVQ network that simply uses a different distance
function. This makes it easier to implement a localized variant.
"""
# Parameters
_omegas: list[torch.Tensor]
masks: list[torch.Tensor]
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
self._masks = ParameterList(
[Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
)
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
@property
def omega_matrices(self):
return [_omega.detach().cpu() for _omega in self._omegas]
@property
def lambda_matrix(self):
# TODO update to respective lambda calculation rules.
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omegas, self._masks)
return distances
def extra_repr(self):
return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
class GMLVQ(GLVQ): class GMLVQ(GLVQ):
"""Generalized Matrix Learning Vector Quantization. """Generalized Matrix Learning Vector Quantization.
@@ -266,13 +291,13 @@ class GMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters # Additional parameters
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get(
EyeLinearTransformInitializer()) "omega_initializer", EyeLinearTransformInitializer()
omega = omega_initializer.generate(self.hparams["input_dim"], )
self.hparams["latent_dim"]) omega = omega_initializer.generate(
self.hparams["input_dim"], self.hparams["latent_dim"]
)
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@property @property
def omega_matrix(self): def omega_matrix(self):

View File

@@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
labels_initializer=LiteralLabelsInitializer(targets)) labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k) self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
return 1 # skip training step return 1 # skip training step
def on_train_batch_start(self, train_batch, batch_idx): def on_train_batch_start(self, train_batch, batch_idx):

View File

@@ -13,7 +13,7 @@ from .glvq import GLVQ
class LVQ1(NonGradientMixin, GLVQ): class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1.""" """Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plables = self.proto_layer() protos, plables = self.proto_layer()
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)
@@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ): class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1.""" """Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
x, y = train_batch x, y = train_batch
@@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
lower_bound = (gamma * f.log()).sum() lower_bound = (gamma * f.log()).sum()
return lower_bound return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx):
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
x, y = train_batch x, y = train_batch

View File

@@ -21,7 +21,7 @@ class CELVQ(GLVQ):
# Loss # Loss
self.loss = torch.nn.CrossEntropyLoss() self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.compute_distances(x) # [None, num_protos] out = self.compute_distances(x) # [None, num_protos]
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
@@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
prediction[confidence < self.rejection_confidence] = -1 prediction[confidence < self.rejection_confidence] = -1
return prediction return prediction
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx):
x, y = batch x, y = batch
out = self.forward(x) out = self.forward(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
@@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
self.loss = torch.nn.KLDivLoss() self.loss = torch.nn.KLDivLoss()
# FIXME # FIXME
# def training_step(self, batch, batch_idx, optimizer_idx=None): # def training_step(self, batch, batch_idx):
# x, y = batch # x, y = batch
# y_pred = self(x) # y_pred = self(x)
# batch_loss = self.loss(y_pred, y) # batch_loss = self.loss(y_pred, y)

View File

@@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
strict=False, strict=False,
) )
def training_epoch_end(self, training_step_outputs): def on_training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp( self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs) -self.current_epoch / self.trainer.max_epochs)

View File

@@ -1,195 +1,193 @@
"""prototorch.models test suite.""" """prototorch.models test suite."""
import prototorch as pt import prototorch.models
import pytest
import torch
def test_glvq_model_build(): def test_glvq_model_build():
model = pt.models.GLVQ( model = prototorch.models.GLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_glvq1_model_build(): def test_glvq1_model_build():
model = pt.models.GLVQ1( model = prototorch.models.GLVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_glvq21_model_build(): def test_glvq21_model_build():
model = pt.models.GLVQ1( model = prototorch.models.GLVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_gmlvq_model_build(): def test_gmlvq_model_build():
model = pt.models.GMLVQ( model = prototorch.models.GMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 2, "input_dim": 2,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_grlvq_model_build(): def test_grlvq_model_build():
model = pt.models.GRLVQ( model = prototorch.models.GRLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 2, "input_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_gtlvq_model_build(): def test_gtlvq_model_build():
model = pt.models.GTLVQ( model = prototorch.models.GTLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_lgmlvq_model_build(): def test_lgmlvq_model_build():
model = pt.models.LGMLVQ( model = prototorch.models.LGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_image_glvq_model_build(): def test_image_glvq_model_build():
model = pt.models.ImageGLVQ( model = prototorch.models.ImageGLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_image_gmlvq_model_build(): def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ( model = prototorch.models.ImageGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 16, "input_dim": 16,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_image_gtlvq_model_build(): def test_image_gtlvq_model_build():
model = pt.models.ImageGMLVQ( model = prototorch.models.ImageGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 16, "input_dim": 16,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(16), prototypes_initializer=prototorch.initializers.RNCI(16),
) )
def test_siamese_glvq_model_build(): def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ( model = prototorch.models.SiameseGLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_siamese_gmlvq_model_build(): def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ( model = prototorch.models.SiameseGMLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_siamese_gtlvq_model_build(): def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ( model = prototorch.models.SiameseGTLVQ(
{ {
"distribution": (3, 2), "distribution": (3, 2),
"input_dim": 4, "input_dim": 4,
"latent_dim": 2, "latent_dim": 2,
}, },
prototypes_initializer=pt.initializers.RNCI(4), prototypes_initializer=prototorch.initializers.RNCI(4),
) )
def test_knn_model_build(): def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = prototorch.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds) model = prototorch.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build(): def test_lvq1_model_build():
model = pt.models.LVQ1( model = prototorch.models.LVQ1(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_lvq21_model_build(): def test_lvq21_model_build():
model = pt.models.LVQ21( model = prototorch.models.LVQ21(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_median_lvq_model_build(): def test_median_lvq_model_build():
model = pt.models.MedianLVQ( model = prototorch.models.MedianLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_celvq_model_build(): def test_celvq_model_build():
model = pt.models.CELVQ( model = prototorch.models.CELVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_rslvq_model_build(): def test_rslvq_model_build():
model = pt.models.RSLVQ( model = prototorch.models.RSLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_slvq_model_build(): def test_slvq_model_build():
model = pt.models.SLVQ( model = prototorch.models.SLVQ(
{"distribution": (3, 2)}, {"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_growing_neural_gas_model_build(): def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas( model = prototorch.models.GrowingNeuralGas(
{"num_prototypes": 5}, {"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_kohonen_som_model_build(): def test_kohonen_som_model_build():
model = pt.models.KohonenSOM( model = prototorch.models.KohonenSOM(
{"shape": (3, 2)}, {"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )
def test_neural_gas_model_build(): def test_neural_gas_model_build():
model = pt.models.NeuralGas( model = prototorch.models.NeuralGas(
{"num_prototypes": 5}, {"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2), prototypes_initializer=prototorch.initializers.RNCI(2),
) )