Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
d4bf6dbbe9 | ||
|
c99fdb436c | ||
|
28ac5f5ed9 | ||
|
b7f510a9fe | ||
|
781ef93b06 | ||
|
072e61b3cd | ||
|
71167a8f77 | ||
|
60990f42d2 | ||
|
1e83c439f7 | ||
|
cbbbbeda98 | ||
|
1b5093627e | ||
|
497da90f9c | ||
|
2a665e220f | ||
|
4cd6aee330 | ||
|
634ef86a2c | ||
|
72e9587a10 | ||
|
f5e1edf31f | ||
|
5e5675d12e | ||
|
16f410e809 | ||
|
46dfb82371 | ||
|
87fa3f0729 | ||
|
08db94d507 | ||
|
8ecf9948b2 | ||
|
c5f0b86114 | ||
|
7506614ada |
@ -1,13 +1,13 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.5.2
|
current_version = 0.7.1
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize = {major}.{minor}.{patch}
|
serialize = {major}.{minor}.{patch}
|
||||||
message = build: bump version {current_version} → {new_version}
|
message = build: bump version {current_version} → {new_version}
|
||||||
|
|
||||||
[bumpversion:file:setup.py]
|
[bumpversion:file:pyproject.toml]
|
||||||
|
|
||||||
[bumpversion:file:./prototorch/models/__init__.py]
|
[bumpversion:file:./src/prototorch/models/__init__.py]
|
||||||
|
|
||||||
[bumpversion:file:./docs/source/conf.py]
|
[bumpversion:file:./docs/source/conf.py]
|
||||||
|
10
.github/workflows/examples.yml
vendored
10
.github/workflows/examples.yml
vendored
@ -6,16 +6,16 @@ name: examples
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
paths:
|
paths:
|
||||||
- 'examples/**.py'
|
- "examples/**.py"
|
||||||
jobs:
|
jobs:
|
||||||
cpu:
|
cpu:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
32
.github/workflows/pythonapp.yml
vendored
32
.github/workflows/pythonapp.yml
vendored
@ -6,42 +6,42 @@ name: tests
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master ]
|
branches: [master]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
style:
|
style:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .[all]
|
pip install .[all]
|
||||||
- uses: pre-commit/action@v2.0.3
|
- uses: pre-commit/action@v3.0.0
|
||||||
compatibility:
|
compatibility:
|
||||||
needs: style
|
needs: style
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
os: [ubuntu-latest, windows-latest]
|
os: [ubuntu-latest, windows-latest]
|
||||||
exclude:
|
exclude:
|
||||||
- os: windows-latest
|
|
||||||
python-version: "3.7"
|
|
||||||
- os: windows-latest
|
- os: windows-latest
|
||||||
python-version: "3.8"
|
python-version: "3.8"
|
||||||
- os: windows-latest
|
- os: windows-latest
|
||||||
python-version: "3.9"
|
python-version: "3.9"
|
||||||
|
- os: windows-latest
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@ -56,18 +56,18 @@ jobs:
|
|||||||
needs: compatibility
|
needs: compatibility
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .[all]
|
pip install .[all]
|
||||||
pip install wheel
|
pip install build
|
||||||
- name: Build package
|
- name: Build package
|
||||||
run: python setup.py sdist bdist_wheel
|
run: python -m build . -C verbose
|
||||||
- name: Publish a Python distribution to PyPI
|
- name: Publish a Python distribution to PyPI
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
with:
|
with:
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.2.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@ -12,41 +12,42 @@ repos:
|
|||||||
- id: check-ast
|
- id: check-ast
|
||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v2.1.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v0.950
|
rev: v1.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
additional_dependencies: [types-pkg_resources]
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
rev: v0.32.0
|
rev: v0.32.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
|
additional_dependencies: ["toml"]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.32.1
|
rev: v3.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
- repo: https://github.com/si-cim/gitlint
|
- repo: https://github.com/si-cim/gitlint
|
||||||
rev: v0.15.2-unofficial
|
rev: v0.15.2-unofficial
|
||||||
hooks:
|
hooks:
|
||||||
- id: gitlint
|
- id: gitlint
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.5.2"
|
release = "0.7.1"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
"""CBC example using the Iris dataset."""
|
"""CBC example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import CBC, VisCBC2D
|
from prototorch.models import CBC, VisCBC2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +18,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -53,8 +53,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -7,13 +7,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
CELVQ,
|
CELVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -26,7 +26,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -83,8 +84,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -7,8 +7,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GLVQ, VisGLVQ2D
|
from prototorch.models import GLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -21,7 +21,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -55,8 +56,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GMLVQ, VisGMLVQ2D
|
from prototorch.models import GMLVQ, VisGMLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -22,7 +22,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -59,8 +60,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGMLVQ2D(data=train_ds)
|
vis = VisGMLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
@ -71,3 +74,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
torch.save(model, "iris.pth")
|
||||||
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
ImageGMLVQ,
|
ImageGMLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisImgComp,
|
VisImgComp,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -26,7 +26,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -96,8 +97,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
GMLVQ,
|
GMLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -25,7 +25,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -78,8 +79,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -7,8 +7,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GrowingNeuralGas, VisNG2D
|
from prototorch.models import GrowingNeuralGas, VisNG2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -51,8 +52,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisNG2D(data=train_loader)
|
vis = VisNG2D(data=train_loader)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
77
examples/grlvq_iris.py
Normal file
77
examples/grlvq_iris.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""GMLVQ example using the Iris dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
|
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
|
||||||
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed_everything(seed=4)
|
||||||
|
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris([0, 1])
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(
|
||||||
|
input_dim=2,
|
||||||
|
distribution={
|
||||||
|
"num_classes": 3,
|
||||||
|
"per_class": 2
|
||||||
|
},
|
||||||
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = GRLVQ(
|
||||||
|
hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = VisSiameseGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
],
|
||||||
|
max_epochs=5,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
detect_anomaly=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
torch.save(model, "iris.pth")
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
ImageGTLVQ,
|
ImageGTLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisImgComp,
|
VisImgComp,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -27,7 +27,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -100,8 +101,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
# using GPUs here is strongly recommended!
|
# using GPUs here is strongly recommended!
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -7,9 +7,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GTLVQ, VisGLVQ2D
|
from prototorch.models import GTLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -61,8 +62,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -59,8 +60,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
max_epochs=1,
|
max_epochs=1,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
|
@ -7,10 +7,10 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.models import KohonenSOM
|
from prototorch.models import KohonenSOM
|
||||||
from prototorch.utils.colors import hex_to_rgb
|
from prototorch.utils.colors import hex_to_rgb
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -104,8 +105,10 @@ if __name__ == "__main__":
|
|||||||
vis = Vis2DColorSOM(data=data)
|
vis = Vis2DColorSOM(data=data)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
max_epochs=500,
|
max_epochs=500,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
|
@ -7,9 +7,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import LGMLVQ, VisGLVQ2D
|
from prototorch.models import LGMLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -62,8 +63,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -6,12 +6,12 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
LVQMLN,
|
LVQMLN,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisSiameseGLVQ2D,
|
VisSiameseGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -88,8 +89,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -6,9 +6,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import MedianLVQ, VisGLVQ2D
|
from prototorch.models import MedianLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -20,7 +20,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -53,8 +54,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import NeuralGas, VisNG2D
|
from prototorch.models import NeuralGas, VisNG2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
@ -23,7 +23,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare and pre-process the dataset
|
# Prepare and pre-process the dataset
|
||||||
@ -60,8 +61,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisNG2D(data=train_ds)
|
vis = VisNG2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import RSLVQ, VisGLVQ2D
|
from prototorch.models import RSLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -54,8 +55,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
|
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -50,8 +51,7 @@ if __name__ == "__main__":
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[1, 2, 3],
|
distribution=[1, 2, 3],
|
||||||
proto_lr=0.01,
|
lr=0.01,
|
||||||
bb_lr=0.01,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the backbone
|
# Initialize the backbone
|
||||||
@ -69,8 +69,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
|
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -50,8 +51,7 @@ if __name__ == "__main__":
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[1, 2, 3],
|
distribution=[1, 2, 3],
|
||||||
proto_lr=0.01,
|
lr=0.01,
|
||||||
bb_lr=0.01,
|
|
||||||
input_dim=2,
|
input_dim=2,
|
||||||
latent_dim=1,
|
latent_dim=1,
|
||||||
)
|
)
|
||||||
@ -71,8 +71,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,6 +6,7 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
GLVQ,
|
GLVQ,
|
||||||
KNN,
|
KNN,
|
||||||
@ -14,7 +15,6 @@ from prototorch.models import (
|
|||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -27,7 +27,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
@ -54,7 +55,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer for GNG
|
# Setup trainer for GNG
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=1000,
|
accelerator="cpu",
|
||||||
|
max_epochs=50 if args.fast_dev_run else
|
||||||
|
1000, # 10 epochs fast dev run reproducible DIV error.
|
||||||
callbacks=[
|
callbacks=[
|
||||||
es,
|
es,
|
||||||
],
|
],
|
||||||
@ -108,8 +111,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
90
pyproject.toml
Normal file
90
pyproject.toml
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
|
||||||
|
[project]
|
||||||
|
name = "prototorch-models"
|
||||||
|
version = "0.7.1"
|
||||||
|
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
|
||||||
|
authors = [
|
||||||
|
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
|
||||||
|
{ name = "Alexander Engelsberger", email = "engelsbe@hs-mittweida.de" },
|
||||||
|
]
|
||||||
|
dependencies = ["lightning>=2.0.0", "prototorch>=0.7.5"]
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
readme = "README.md"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 2 - Pre-Alpha",
|
||||||
|
"Environment :: Plugins",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"Intended Audience :: Education",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
"Topic :: Software Development :: Libraries",
|
||||||
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/si-cim/prototorch_models"
|
||||||
|
Downloads = "https://github.com/si-cim/prototorch_models.git"
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = ["bumpversion", "pre-commit", "yapf", "toml"]
|
||||||
|
examples = ["matplotlib", "scikit-learn"]
|
||||||
|
ci = ["pytest", "pre-commit"]
|
||||||
|
docs = [
|
||||||
|
"recommonmark",
|
||||||
|
"nbsphinx",
|
||||||
|
"sphinx",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
"sphinxcontrib-bibtex",
|
||||||
|
"sphinxcontrib-katex",
|
||||||
|
"ipykernel",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"bumpversion",
|
||||||
|
"pre-commit",
|
||||||
|
"yapf",
|
||||||
|
"toml",
|
||||||
|
"pytest",
|
||||||
|
"matplotlib",
|
||||||
|
"scikit-learn",
|
||||||
|
"recommonmark",
|
||||||
|
"nbsphinx",
|
||||||
|
"sphinx",
|
||||||
|
"sphinx_rtd_theme",
|
||||||
|
"sphinxcontrib-bibtex",
|
||||||
|
"sphinxcontrib-katex",
|
||||||
|
"ipykernel",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.yapf]
|
||||||
|
based_on_style = "pep8"
|
||||||
|
spaces_before_comment = 2
|
||||||
|
split_before_logical_operator = true
|
||||||
|
|
||||||
|
[tool.pylint]
|
||||||
|
disable = ["too-many-arguments", "too-few-public-methods", "fixme"]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "hug"
|
||||||
|
src_paths = ["isort", "test"]
|
||||||
|
multi_line_output = 3
|
||||||
|
include_trailing_comma = true
|
||||||
|
force_grid_wrap = 3
|
||||||
|
use_parentheses = true
|
||||||
|
line_length = 79
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
explicit_package_bases = true
|
||||||
|
namespace_packages = true
|
23
setup.cfg
23
setup.cfg
@ -1,23 +0,0 @@
|
|||||||
[yapf]
|
|
||||||
based_on_style = pep8
|
|
||||||
spaces_before_comment = 2
|
|
||||||
split_before_logical_operator = true
|
|
||||||
|
|
||||||
[pylint]
|
|
||||||
disable =
|
|
||||||
too-many-arguments,
|
|
||||||
too-few-public-methods,
|
|
||||||
fixme,
|
|
||||||
|
|
||||||
|
|
||||||
[pycodestyle]
|
|
||||||
max-line-length = 79
|
|
||||||
|
|
||||||
[isort]
|
|
||||||
profile = hug
|
|
||||||
src_paths = isort, test
|
|
||||||
multi_line_output = 3
|
|
||||||
include_trailing_comma = True
|
|
||||||
force_grid_wrap = 3
|
|
||||||
use_parentheses = True
|
|
||||||
line_length = 79
|
|
99
setup.py
99
setup.py
@ -1,99 +0,0 @@
|
|||||||
"""
|
|
||||||
|
|
||||||
######
|
|
||||||
# # ##### #### ##### #### ##### #### ##### #### # #
|
|
||||||
# # # # # # # # # # # # # # # # # #
|
|
||||||
###### # # # # # # # # # # # # # ######
|
|
||||||
# ##### # # # # # # # # ##### # # #
|
|
||||||
# # # # # # # # # # # # # # # # #
|
|
||||||
# # # #### # #### # #### # # #### # #Plugin
|
|
||||||
|
|
||||||
ProtoTorch models Plugin Package
|
|
||||||
"""
|
|
||||||
from pkg_resources import safe_name
|
|
||||||
from setuptools import find_namespace_packages, setup
|
|
||||||
|
|
||||||
PLUGIN_NAME = "models"
|
|
||||||
|
|
||||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
|
||||||
|
|
||||||
with open("README.md", "r") as fh:
|
|
||||||
long_description = fh.read()
|
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
|
||||||
"prototorch>=0.7.3",
|
|
||||||
"pytorch_lightning>=1.6.0",
|
|
||||||
"torchmetrics",
|
|
||||||
"protobuf<3.20.0",
|
|
||||||
]
|
|
||||||
CLI = [
|
|
||||||
"jsonargparse",
|
|
||||||
]
|
|
||||||
DEV = [
|
|
||||||
"bumpversion",
|
|
||||||
"pre-commit",
|
|
||||||
]
|
|
||||||
DOCS = [
|
|
||||||
"recommonmark",
|
|
||||||
"sphinx",
|
|
||||||
"nbsphinx",
|
|
||||||
"ipykernel",
|
|
||||||
"sphinx_rtd_theme",
|
|
||||||
"sphinxcontrib-katex",
|
|
||||||
"sphinxcontrib-bibtex",
|
|
||||||
]
|
|
||||||
EXAMPLES = [
|
|
||||||
"matplotlib",
|
|
||||||
"scikit-learn",
|
|
||||||
]
|
|
||||||
TESTS = [
|
|
||||||
"codecov",
|
|
||||||
"pytest",
|
|
||||||
]
|
|
||||||
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
|
||||||
version="0.5.2",
|
|
||||||
description="Pre-packaged prototype-based "
|
|
||||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
|
||||||
long_description=long_description,
|
|
||||||
long_description_content_type="text/markdown",
|
|
||||||
author="Alexander Engelsberger",
|
|
||||||
author_email="engelsbe@hs-mittweida.de",
|
|
||||||
url=PROJECT_URL,
|
|
||||||
download_url=DOWNLOAD_URL,
|
|
||||||
license="MIT",
|
|
||||||
python_requires=">=3.7",
|
|
||||||
install_requires=INSTALL_REQUIRES,
|
|
||||||
extras_require={
|
|
||||||
"dev": DEV,
|
|
||||||
"examples": EXAMPLES,
|
|
||||||
"tests": TESTS,
|
|
||||||
"all": ALL,
|
|
||||||
},
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 2 - Pre-Alpha",
|
|
||||||
"Environment :: Plugins",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"Intended Audience :: Education",
|
|
||||||
"Intended Audience :: Science/Research",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Natural Language :: English",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.7",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Software Development :: Libraries",
|
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
||||||
],
|
|
||||||
entry_points={
|
|
||||||
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
|
|
||||||
},
|
|
||||||
packages=find_namespace_packages(include=["prototorch.*"]),
|
|
||||||
zip_safe=False,
|
|
||||||
)
|
|
@ -36,4 +36,4 @@ from .unsupervised import (
|
|||||||
)
|
)
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.5.2"
|
__version__ = "0.7.1"
|
@ -71,7 +71,7 @@ class PrototypeModel(ProtoTorchBolt):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
self.distance_layer = LambdaLayer(distance_fn)
|
self.distance_layer = LambdaLayer(distance_fn, name="distance_fn")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_prototypes(self):
|
def num_prototypes(self):
|
||||||
@ -186,26 +186,37 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
def log_acc(self, distances, targets, tag):
|
def log_acc(self, distances, targets, tag):
|
||||||
preds = self.predict_from_distances(distances)
|
preds = self.predict_from_distances(distances)
|
||||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
accuracy = torchmetrics.functional.accuracy(
|
||||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
preds.int(),
|
||||||
|
targets.int(),
|
||||||
|
"multiclass",
|
||||||
|
num_classes=self.num_classes,
|
||||||
|
)
|
||||||
|
|
||||||
self.log(tag,
|
self.log(
|
||||||
|
tag,
|
||||||
accuracy,
|
accuracy,
|
||||||
on_step=False,
|
on_step=False,
|
||||||
on_epoch=True,
|
on_epoch=True,
|
||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
logger=True)
|
logger=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
x, targets = batch
|
x, targets = batch
|
||||||
|
|
||||||
preds = self.predict(x)
|
preds = self.predict(x)
|
||||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
accuracy = torchmetrics.functional.accuracy(
|
||||||
|
preds.int(),
|
||||||
|
targets.int(),
|
||||||
|
"multiclass",
|
||||||
|
num_classes=self.num_classes,
|
||||||
|
)
|
||||||
|
|
||||||
self.log("test_acc", accuracy)
|
self.log("test_acc", accuracy)
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin:
|
||||||
"""All mixins are ProtoTorchMixins."""
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
|
||||||
|
|
||||||
@ -216,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization = False
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
|
|||||||
probs = self.competition_layer(detections, reasonings)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.num_classes
|
num_classes = self.num_classes
|
||||||
@ -52,17 +52,23 @@ class CBC(SiameseGLVQ):
|
|||||||
loss = self.loss(y_pred, y_true).mean()
|
loss = self.loss(y_pred, y_true).mean()
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
y_pred, train_loss = self.shared_step(batch, batch_idx)
|
||||||
preds = torch.argmax(y_pred, dim=1)
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
accuracy = torchmetrics.functional.accuracy(
|
||||||
batch[1].int())
|
preds.int(),
|
||||||
self.log("train_acc",
|
batch[1].int(),
|
||||||
|
"multiclass",
|
||||||
|
num_classes=self.num_classes,
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
"train_acc",
|
||||||
accuracy,
|
accuracy,
|
||||||
on_step=False,
|
on_step=False,
|
||||||
on_epoch=True,
|
on_epoch=True,
|
||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
logger=True)
|
logger=True,
|
||||||
|
)
|
||||||
return train_loss
|
return train_loss
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
|
|||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||||
omegas, omegas.permute([0, 2, 1]))
|
omegas, omegas.permute([0, 2, 1]))
|
||||||
projected_x = x @ p
|
projected_x = x @ p
|
@ -66,15 +66,15 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
prototype_wr,
|
prototype_wr,
|
||||||
])
|
])
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
loss = self.loss(out, y, plabels)
|
loss = self.loss(out, y, plabels)
|
||||||
return out, loss
|
return out, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
out, train_loss = self.shared_step(batch, batch_idx)
|
||||||
self.log_prototype_win_ratios(out)
|
self.log_prototype_win_ratios(out)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
@ -99,10 +99,6 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
test_loss += batch_loss.item()
|
test_loss += batch_loss.item()
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
# TODO
|
|
||||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
||||||
@ -123,29 +119,9 @@ class SiameseGLVQ(GLVQ):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.both_path_gradients = both_path_gradients
|
self.both_path_gradients = both_path_gradients
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
|
||||||
lr=self.hparams["proto_lr"])
|
|
||||||
# Only add a backbone optimizer if backbone has trainable parameters
|
|
||||||
bb_params = list(self.backbone.parameters())
|
|
||||||
if (bb_params):
|
|
||||||
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
|
|
||||||
optimizers = [proto_opt, bb_opt]
|
|
||||||
else:
|
|
||||||
optimizers = [proto_opt]
|
|
||||||
if self.lr_scheduler is not None:
|
|
||||||
schedulers = []
|
|
||||||
for optimizer in optimizers:
|
|
||||||
scheduler = self.lr_scheduler(optimizer,
|
|
||||||
**self.lr_scheduler_kwargs)
|
|
||||||
schedulers.append(scheduler)
|
|
||||||
return optimizers, schedulers
|
|
||||||
else:
|
|
||||||
return optimizers
|
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
|
||||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||||
@ -209,9 +185,12 @@ class GRLVQ(SiameseGLVQ):
|
|||||||
self.register_parameter("_relevances", Parameter(relevances))
|
self.register_parameter("_relevances", Parameter(relevances))
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
|
self.backbone = LambdaLayer(self._apply_relevances,
|
||||||
name="relevance scaling")
|
name="relevance scaling")
|
||||||
|
|
||||||
|
def _apply_relevances(self, x):
|
||||||
|
return x @ torch.diag(self._relevances)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relevance_profile(self):
|
def relevance_profile(self):
|
||||||
return self._relevances.detach().cpu()
|
return self._relevances.detach().cpu()
|
||||||
@ -271,8 +250,6 @@ class GMLVQ(GLVQ):
|
|||||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
omega = omega_initializer.generate(self.hparams["input_dim"],
|
||||||
self.hparams["latent_dim"])
|
self.hparams["latent_dim"])
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
self.backbone = LambdaLayer(lambda x: x @ self._omega,
|
|
||||||
name="omega matrix")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
labels_initializer=LiteralLabelsInitializer(targets))
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
return 1 # skip training step
|
return 1 # skip training step
|
||||||
|
|
||||||
def on_train_batch_start(self, train_batch, batch_idx):
|
def on_train_batch_start(self, train_batch, batch_idx):
|
@ -13,7 +13,7 @@ from .glvq import GLVQ
|
|||||||
class LVQ1(NonGradientMixin, GLVQ):
|
class LVQ1(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 1."""
|
"""Learning Vector Quantization 1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plables = self.proto_layer()
|
protos, plables = self.proto_layer()
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
class LVQ21(NonGradientMixin, GLVQ):
|
class LVQ21(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 2.1."""
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
lower_bound = (gamma * f.log()).sum()
|
lower_bound = (gamma * f.log()).sum()
|
||||||
return lower_bound
|
return lower_bound
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
@ -21,7 +21,7 @@ class CELVQ(GLVQ):
|
|||||||
# Loss
|
# Loss
|
||||||
self.loss = torch.nn.CrossEntropyLoss()
|
self.loss = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
prediction[confidence < self.rejection_confidence] = -1
|
prediction[confidence < self.rejection_confidence] = -1
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.forward(x)
|
out = self.forward(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
|||||||
self.loss = torch.nn.KLDivLoss()
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
# def training_step(self, batch, batch_idx):
|
||||||
# x, y = batch
|
# x, y = batch
|
||||||
# y_pred = self(x)
|
# y_pred = self(x)
|
||||||
# batch_loss = self.loss(y_pred, y)
|
# batch_loss = self.loss(y_pred, y)
|
@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
strict=False,
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def training_epoch_end(self, training_step_outputs):
|
def on_training_epoch_end(self, training_step_outputs):
|
||||||
self._sigma = self.hparams.sigma * np.exp(
|
self._sigma = self.hparams.sigma * np.exp(
|
||||||
-self.current_epoch / self.trainer.max_epochs)
|
-self.current_epoch / self.trainer.max_epochs)
|
||||||
|
|
@ -1,195 +1,193 @@
|
|||||||
"""prototorch.models test suite."""
|
"""prototorch.models test suite."""
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch.models
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def test_glvq_model_build():
|
def test_glvq_model_build():
|
||||||
model = pt.models.GLVQ(
|
model = prototorch.models.GLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_glvq1_model_build():
|
def test_glvq1_model_build():
|
||||||
model = pt.models.GLVQ1(
|
model = prototorch.models.GLVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_glvq21_model_build():
|
def test_glvq21_model_build():
|
||||||
model = pt.models.GLVQ1(
|
model = prototorch.models.GLVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_gmlvq_model_build():
|
def test_gmlvq_model_build():
|
||||||
model = pt.models.GMLVQ(
|
model = prototorch.models.GMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 2,
|
"input_dim": 2,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_grlvq_model_build():
|
def test_grlvq_model_build():
|
||||||
model = pt.models.GRLVQ(
|
model = prototorch.models.GRLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 2,
|
"input_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_gtlvq_model_build():
|
def test_gtlvq_model_build():
|
||||||
model = pt.models.GTLVQ(
|
model = prototorch.models.GTLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_lgmlvq_model_build():
|
def test_lgmlvq_model_build():
|
||||||
model = pt.models.LGMLVQ(
|
model = prototorch.models.LGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_glvq_model_build():
|
def test_image_glvq_model_build():
|
||||||
model = pt.models.ImageGLVQ(
|
model = prototorch.models.ImageGLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(16),
|
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_gmlvq_model_build():
|
def test_image_gmlvq_model_build():
|
||||||
model = pt.models.ImageGMLVQ(
|
model = prototorch.models.ImageGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 16,
|
"input_dim": 16,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(16),
|
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_gtlvq_model_build():
|
def test_image_gtlvq_model_build():
|
||||||
model = pt.models.ImageGMLVQ(
|
model = prototorch.models.ImageGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 16,
|
"input_dim": 16,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(16),
|
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_glvq_model_build():
|
def test_siamese_glvq_model_build():
|
||||||
model = pt.models.SiameseGLVQ(
|
model = prototorch.models.SiameseGLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(4),
|
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gmlvq_model_build():
|
def test_siamese_gmlvq_model_build():
|
||||||
model = pt.models.SiameseGMLVQ(
|
model = prototorch.models.SiameseGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(4),
|
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gtlvq_model_build():
|
def test_siamese_gtlvq_model_build():
|
||||||
model = pt.models.SiameseGTLVQ(
|
model = prototorch.models.SiameseGTLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=pt.initializers.RNCI(4),
|
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_knn_model_build():
|
def test_knn_model_build():
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
train_ds = prototorch.datasets.Iris(dims=[0, 2])
|
||||||
model = pt.models.KNN(dict(k=3), data=train_ds)
|
model = prototorch.models.KNN(dict(k=3), data=train_ds)
|
||||||
|
|
||||||
|
|
||||||
def test_lvq1_model_build():
|
def test_lvq1_model_build():
|
||||||
model = pt.models.LVQ1(
|
model = prototorch.models.LVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_lvq21_model_build():
|
def test_lvq21_model_build():
|
||||||
model = pt.models.LVQ21(
|
model = prototorch.models.LVQ21(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_median_lvq_model_build():
|
def test_median_lvq_model_build():
|
||||||
model = pt.models.MedianLVQ(
|
model = prototorch.models.MedianLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_celvq_model_build():
|
def test_celvq_model_build():
|
||||||
model = pt.models.CELVQ(
|
model = prototorch.models.CELVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_rslvq_model_build():
|
def test_rslvq_model_build():
|
||||||
model = pt.models.RSLVQ(
|
model = prototorch.models.RSLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_slvq_model_build():
|
def test_slvq_model_build():
|
||||||
model = pt.models.SLVQ(
|
model = prototorch.models.SLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_growing_neural_gas_model_build():
|
def test_growing_neural_gas_model_build():
|
||||||
model = pt.models.GrowingNeuralGas(
|
model = prototorch.models.GrowingNeuralGas(
|
||||||
{"num_prototypes": 5},
|
{"num_prototypes": 5},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kohonen_som_model_build():
|
def test_kohonen_som_model_build():
|
||||||
model = pt.models.KohonenSOM(
|
model = prototorch.models.KohonenSOM(
|
||||||
{"shape": (3, 2)},
|
{"shape": (3, 2)},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_neural_gas_model_build():
|
def test_neural_gas_model_build():
|
||||||
model = pt.models.NeuralGas(
|
model = prototorch.models.NeuralGas(
|
||||||
{"num_prototypes": 5},
|
{"num_prototypes": 5},
|
||||||
prototypes_initializer=pt.initializers.RNCI(2),
|
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user