10 Commits

Author SHA1 Message Date
Alexander Engelsberger
1b5093627e build: bump version 0.5.4 → 0.6.0 2023-06-20 18:50:03 +02:00
Alexander Engelsberger
497da90f9c chore: small changes to configuration 2023-06-20 18:49:57 +02:00
Alexander Engelsberger
2a665e220f fix: use multiclass accuracy by default 2023-06-20 18:30:18 +02:00
Alexander Engelsberger
4cd6aee330 chore: replace config by pyproject.toml 2023-06-20 18:30:05 +02:00
Alexander Engelsberger
634ef86a2c fix: example test fixed 2023-06-20 17:42:36 +02:00
Alexander Engelsberger
72e9587a10 fix: remove removed CLI syntax from examples 2023-06-20 17:30:21 +02:00
Alexander Engelsberger
f5e1edf31f ci: upgrade workflows 2023-06-20 16:39:13 +02:00
Alexander Engelsberger
5e5675d12e ci: upgrade pre-commit config 2023-06-20 16:37:11 +02:00
Alexander Engelsberger
16f410e809 fix: style fixes 2023-03-09 15:59:49 +01:00
Alexander Engelsberger
46dfb82371 Fix: saving GMLVQ and GRLVQ fixed 2023-03-09 15:50:13 +01:00
35 changed files with 460 additions and 354 deletions

View File

@@ -1,12 +1,12 @@
[bumpversion]
current_version = 0.5.4
current_version = 0.6.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py]
[bumpversion:file:pyproject.toml]
[bumpversion:file:./prototorch/models/__init__.py]

View File

@@ -6,20 +6,20 @@ name: examples
on:
push:
paths:
- 'examples/**.py'
- "examples/**.py"
jobs:
cpu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/

View File

@@ -6,70 +6,70 @@ name: tests
on:
push:
pull_request:
branches: [ master ]
branches: [master]
jobs:
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v2.0.3
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v3.0.0
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
- os: windows-latest
python-version: "3.10"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -2,52 +2,53 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v2.0.1
hooks:
- id: autoflake
- repo: https://github.com/myint/autoflake
rev: v2.1.1
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: http://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
hooks:
- id: yapf
additional_dependencies: ["toml"]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
hooks:
- id: pyupgrade
- repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.5.4"
release = "0.6.0"
# -- General configuration ---------------------------------------------------

View File

@@ -5,8 +5,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -19,7 +19,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -53,8 +54,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -7,13 +7,13 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
CELVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -26,7 +26,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -83,8 +84,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,

View File

@@ -7,8 +7,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
@@ -21,7 +21,8 @@ if __name__ == "__main__":
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -55,8 +56,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
@@ -22,7 +22,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -59,8 +60,10 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
@@ -71,3 +74,5 @@ if __name__ == "__main__":
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
ImageGMLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
@@ -26,7 +26,8 @@ if __name__ == "__main__":
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -96,8 +97,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
GMLVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -25,7 +25,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -78,8 +79,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,

View File

@@ -7,8 +7,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
@@ -51,8 +52,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_loader)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

77
examples/grlvq_iris.py Normal file
View File

@@ -0,0 +1,77 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris([0, 1])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=2,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = GRLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],
max_epochs=5,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
torch.save(model, "iris.pth")

View File

@@ -6,13 +6,13 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
ImageGTLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
@@ -27,7 +27,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -100,8 +101,10 @@ if __name__ == "__main__":
# Setup trainer
# using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,

View File

@@ -7,9 +7,9 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
@@ -61,8 +62,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,

View File

@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -59,8 +60,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=1,
callbacks=[
vis,

View File

@@ -7,10 +7,10 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset
@@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
@@ -104,8 +105,10 @@ if __name__ == "__main__":
vis = Vis2DColorSOM(data=data)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=500,
callbacks=[
vis,

View File

@@ -7,9 +7,9 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
@@ -62,8 +63,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,

View File

@@ -6,12 +6,12 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
LVQMLN,
PruneLoserPrototypes,
VisSiameseGLVQ2D,
)
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -88,8 +89,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,

View File

@@ -6,9 +6,9 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -20,7 +20,8 @@ if __name__ == "__main__":
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -53,8 +54,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
es,

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
@@ -23,7 +23,8 @@ if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Prepare and pre-process the dataset
@@ -60,8 +61,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Reproducibility
@@ -54,8 +55,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
lr=0.01,
)
# Initialize the backbone
@@ -69,8 +69,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -6,8 +6,8 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Dataset
@@ -50,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
lr=0.01,
input_dim=2,
latent_dim=1,
)
@@ -71,8 +71,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
],

View File

@@ -6,6 +6,7 @@ import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import (
GLVQ,
KNN,
@@ -14,7 +15,6 @@ from prototorch.models import (
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
@@ -27,7 +27,8 @@ if __name__ == "__main__":
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args()
# Prepare the data
@@ -54,7 +55,9 @@ if __name__ == "__main__":
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=1000,
accelerator="cpu",
max_epochs=50 if args.fast_dev_run else
1000, # 10 epochs fast dev run reproducible DIV error.
callbacks=[
es,
],
@@ -108,8 +111,10 @@ if __name__ == "__main__":
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
trainer = pl.Trainer(
accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[
vis,
pruning,

BIN
glvq_iris.ckpt Normal file

Binary file not shown.

BIN
iris.pth Normal file

Binary file not shown.

View File

@@ -36,4 +36,4 @@ from .unsupervised import (
)
from .vis import *
__version__ = "0.5.4"
__version__ = "0.6.0"

View File

@@ -2,6 +2,7 @@
import logging
import prototorch
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
@@ -71,7 +72,7 @@ class PrototypeModel(ProtoTorchBolt):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
@property
def num_prototypes(self):
@@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
self.log(
tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy)

View File

@@ -55,14 +55,20 @@ class CBC(SiameseGLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(),
batch[1].int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
accuracy = torchmetrics.functional.accuracy(
preds.int(),
batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return train_loss
def predict(self, x):

View File

@@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
self.backbone = backbone
self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams["proto_lr"])
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
@@ -209,9 +189,12 @@ class GRLVQ(SiameseGLVQ):
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
self.backbone = LambdaLayer(self._apply_relevances,
name="relevance scaling")
def _apply_relevances(self, x):
return x @ torch.diag(self._relevances)
@property
def relevance_profile(self):
return self._relevances.detach().cpu()
@@ -271,8 +254,6 @@ class GMLVQ(GLVQ):
omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"])
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@property
def omega_matrix(self):

View File

@@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
strict=False,
)
def training_epoch_end(self, training_step_outputs):
def on_training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs)

96
pyproject.toml Normal file
View File

@@ -0,0 +1,96 @@
[project]
name = "prototorch-models"
version = "0.6.0"
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
authors = [
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
]
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
requires-python = ">=3.8"
readme = "README.md"
license = { text = "MIT" }
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Homepage = "https://github.com/si-cim/prototorch_models"
Downloads = "https://github.com/si-cim/prototorch_models.git"
[project.optional-dependencies]
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
examples = ["matplotlib", "scikit-learn"]
ci = ["pytest", "pre-commit"]
docs = [
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
all = [
"bumpversion",
"pre-commit",
"yapf",
"toml",
"pytest",
"matplotlib",
"scikit-learn",
"recommonmark",
"nbsphinx",
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
"ipykernel",
]
[project.entry-points."prototorch.plugins"]
models = "prototorch.models"
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 2
split_before_logical_operator = true
[tool.pylint]
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
[tool.isort]
profile = "hug"
src_paths = ["isort", "test"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 3
use_parentheses = true
line_length = 79
[tool.mypy]
explicit_package_bases = true
namespace_packages = true
[tool.setuptools]
py-modules = ["prototorch"]

View File

@@ -1,27 +0,0 @@
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79
[mypy]
explicit_package_bases = True
namespace_packages = True

View File

@@ -1,99 +0,0 @@
"""
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics<0.10.0",
"protobuf<3.20.0",
]
CLI = [
"jsonargparse",
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
]
EXAMPLES = [
"matplotlib",
"scikit-learn",
]
TESTS = [
"codecov",
"pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.5.4",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
"examples": EXAMPLES,
"tests": TESTS,
"all": ALL,
},
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Environment :: Plugins",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
},
packages=find_namespace_packages(include=["prototorch.*"]),
zip_safe=False,
)