refactor(api)!: merge the new api changes into dev
This commit is contained in:
commit
d42693a441
@ -4,7 +4,10 @@ commit = True
|
|||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize = {major}.{minor}.{patch}
|
serialize = {major}.{minor}.{patch}
|
||||||
|
message = bump: {current_version} → {new_version}
|
||||||
|
|
||||||
[bumpversion:file:setup.py]
|
[bumpversion:file:setup.py]
|
||||||
|
|
||||||
[bumpversion:file:./prototorch/models/__init__.py]
|
[bumpversion:file:./prototorch/models/__init__.py]
|
||||||
|
|
||||||
|
[bumpversion:file:./docs/source/conf.py]
|
||||||
|
17
.gitignore
vendored
17
.gitignore
vendored
@ -128,14 +128,19 @@ dmypy.json
|
|||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
# Datasets
|
|
||||||
datasets/
|
|
||||||
|
|
||||||
# PyTorch-Lightning
|
|
||||||
lightning_logs/
|
|
||||||
|
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
# Vim
|
||||||
|
*~
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
# Pytorch Models or Weights
|
# Pytorch Models or Weights
|
||||||
# If necessary make exceptions for single pretrained models
|
# If necessary make exceptions for single pretrained models
|
||||||
*.pt
|
*.pt
|
||||||
|
|
||||||
|
# Artifacts created by ProtoTorch Models
|
||||||
|
datasets/
|
||||||
|
lightning_logs/
|
||||||
|
examples/_*.py
|
||||||
|
examples/_*.ipynb
|
||||||
|
@ -1,54 +1,53 @@
|
|||||||
# See https://pre-commit.com for more information
|
# See https://pre-commit.com for more information
|
||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
repos:
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
rev: v4.0.1
|
|
||||||
hooks:
|
|
||||||
- id: trailing-whitespace
|
|
||||||
- id: end-of-file-fixer
|
|
||||||
- id: check-yaml
|
|
||||||
- id: check-added-large-files
|
|
||||||
- id: check-ast
|
|
||||||
- id: check-case-conflict
|
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.0.1
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-ast
|
||||||
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v1.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.8.0
|
rev: 5.8.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: 'v0.902'
|
rev: v0.902
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
additional_dependencies: [types-pkg_resources]
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
rev: v0.31.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0 # Use the ref you want to point at
|
rev: v1.9.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
|
rev: v2.19.4
|
||||||
|
hooks:
|
||||||
|
- id: pyupgrade
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/si-cim/gitlint
|
||||||
rev: v2.19.4
|
rev: v0.15.2-unofficial
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: gitlint
|
||||||
|
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||||
- repo: https://github.com/jorisroovers/gitlint
|
|
||||||
rev: "v0.15.1"
|
|
||||||
hooks:
|
|
||||||
- id: gitlint
|
|
||||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
|
||||||
|
34
README.md
34
README.md
@ -20,23 +20,6 @@ pip install prototorch_models
|
|||||||
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
|
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
|
||||||
be available for use in your Python environment as `prototorch.models`.
|
be available for use in your Python environment as `prototorch.models`.
|
||||||
|
|
||||||
## Contribution
|
|
||||||
|
|
||||||
This repository contains definition for [git hooks](https://githooks.com).
|
|
||||||
[Pre-commit](https://pre-commit.com) is automatically installed as development
|
|
||||||
dependency with prototorch or you can install it manually with `pip install
|
|
||||||
pre-commit`.
|
|
||||||
|
|
||||||
Please install the hooks by running:
|
|
||||||
```bash
|
|
||||||
pre-commit install
|
|
||||||
pre-commit install --hook-type commit-msg
|
|
||||||
```
|
|
||||||
before creating the first commit.
|
|
||||||
|
|
||||||
The commit will fail if the commit message does not follow the specification
|
|
||||||
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
|
|
||||||
|
|
||||||
## Available models
|
## Available models
|
||||||
|
|
||||||
### LVQ Family
|
### LVQ Family
|
||||||
@ -103,6 +86,23 @@ To assist in the development process, you may also find it useful to install
|
|||||||
please avoid installing Tensorflow in this environment. It is known to cause
|
please avoid installing Tensorflow in this environment. It is known to cause
|
||||||
problems with PyTorch-Lightning.**
|
problems with PyTorch-Lightning.**
|
||||||
|
|
||||||
|
## Contribution
|
||||||
|
|
||||||
|
This repository contains definition for [git hooks](https://githooks.com).
|
||||||
|
[Pre-commit](https://pre-commit.com) is automatically installed as development
|
||||||
|
dependency with prototorch or you can install it manually with `pip install
|
||||||
|
pre-commit`.
|
||||||
|
|
||||||
|
Please install the hooks by running:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
pre-commit install --hook-type commit-msg
|
||||||
|
```
|
||||||
|
before creating the first commit.
|
||||||
|
|
||||||
|
The commit will fail if the commit message does not follow the specification
|
||||||
|
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|
||||||
### How do I update the plugin?
|
### How do I update the plugin?
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.4.4"
|
release = "0.1.8"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -24,14 +23,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[2, 2, 2],
|
distribution=[1, 0, 3],
|
||||||
proto_lr=0.1,
|
margin=0.1,
|
||||||
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.CBC(
|
model = pt.models.CBC(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
|
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
|
||||||
|
reasonings_iniitializer=pt.initializers.
|
||||||
|
PurePositiveReasoningsInitializer(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -37,7 +36,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.CELVQ(
|
model = pt.models.CELVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.Ones(2, scale=3),
|
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
|
@ -2,12 +2,11 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -24,7 +23,7 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution={
|
distribution={
|
||||||
"num_classes": 3,
|
"num_classes": 3,
|
||||||
"prototypes_per_class": 4
|
"per_class": 4
|
||||||
},
|
},
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
@ -33,7 +32,7 @@ if __name__ == "__main__":
|
|||||||
model = pt.models.GLVQ(
|
model = pt.models.GLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
lr_scheduler=ExponentialLR,
|
lr_scheduler=ExponentialLR,
|
||||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
)
|
)
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -26,7 +25,6 @@ if __name__ == "__main__":
|
|||||||
distribution=(num_classes, prototypes_per_class),
|
distribution=(num_classes, prototypes_per_class),
|
||||||
transfer_function="swish_beta",
|
transfer_function="swish_beta",
|
||||||
transfer_beta=10.0,
|
transfer_beta=10.0,
|
||||||
# lr=0.1,
|
|
||||||
proto_lr=0.1,
|
proto_lr=0.1,
|
||||||
bb_lr=0.1,
|
bb_lr=0.1,
|
||||||
input_dim=2,
|
input_dim=2,
|
||||||
@ -37,7 +35,7 @@ if __name__ == "__main__":
|
|||||||
model = pt.models.GMLVQ(
|
model = pt.models.GMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
@ -47,12 +45,12 @@ if __name__ == "__main__":
|
|||||||
block=False,
|
block=False,
|
||||||
)
|
)
|
||||||
pruning = pt.models.PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.02,
|
threshold=0.01,
|
||||||
idle_epochs=10,
|
idle_epochs=10,
|
||||||
prune_quota_per_epoch=5,
|
prune_quota_per_epoch=5,
|
||||||
frequency=2,
|
frequency=5,
|
||||||
replace=True,
|
replace=True,
|
||||||
initializer=pt.components.SSI(train_ds, noise=1e-2),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = pl.callbacks.EarlyStopping(
|
es = pl.callbacks.EarlyStopping(
|
||||||
@ -68,7 +66,7 @@ if __name__ == "__main__":
|
|||||||
args,
|
args,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
# es,
|
# es, # FIXME
|
||||||
pruning,
|
pruning,
|
||||||
],
|
],
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
"""GLVQ example using the Iris dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Iris()
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
hparams = dict(
|
|
||||||
input_dim=4,
|
|
||||||
latent_dim=3,
|
|
||||||
distribution={
|
|
||||||
"num_classes": 3,
|
|
||||||
"prototypes_per_class": 2
|
|
||||||
},
|
|
||||||
proto_lr=0.0005,
|
|
||||||
bb_lr=0.0005,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = pt.models.GMLVQ(
|
|
||||||
hparams,
|
|
||||||
optimizer=torch.optim.Adam,
|
|
||||||
prototype_initializer=pt.components.SSI(train_ds),
|
|
||||||
lr_scheduler=ExponentialLR,
|
|
||||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
|
||||||
omega_initializer=pt.components.PCA(train_ds.data)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
|
||||||
#model.example_input_array = torch.zeros(4, 2)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
|
||||||
args,
|
|
||||||
callbacks=[vis],
|
|
||||||
weights_summary="full",
|
|
||||||
accelerator="ddp",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -30,7 +29,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GrowingNeuralGas(
|
model = pt.models.GrowingNeuralGas(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.Zeros(2),
|
prototypes_initializer=pt.initializers.ZCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
|
@ -2,25 +2,11 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
from prototorch.utils.colors import hex_to_rgb
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
|
|
||||||
def hex_to_rgb(hex_values):
|
|
||||||
for v in hex_values:
|
|
||||||
v = v.lstrip('#')
|
|
||||||
lv = len(v)
|
|
||||||
c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
def rgb_to_hex(rgb_values):
|
|
||||||
for v in rgb_values:
|
|
||||||
c = "%02x%02x%02x" % tuple(v)
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
class Vis2DColorSOM(pl.Callback):
|
class Vis2DColorSOM(pl.Callback):
|
||||||
@ -93,7 +79,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.KohonenSOM(
|
model = pt.models.KohonenSOM(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.Random(3),
|
prototypes_initializer=pt.initializers.RNCI(3),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
|
@ -2,23 +2,22 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
pl.utilities.seed.seed_everything(seed=2)
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
batch_size=256,
|
batch_size=256,
|
||||||
@ -32,8 +31,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.LGMLVQ(hparams,
|
model = pt.models.LGMLVQ(
|
||||||
prototype_initializer=pt.components.SMI(train_ds))
|
hparams,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
@ -3,11 +3,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
|
|
||||||
def plot_matrix(matrix):
|
def plot_matrix(matrix):
|
||||||
title = "Lambda matrix"
|
title = "Lambda matrix"
|
||||||
@ -40,20 +39,19 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution={
|
distribution={
|
||||||
"num_classes": 2,
|
"num_classes": 2,
|
||||||
"prototypes_per_class": 1
|
"per_class": 1,
|
||||||
},
|
},
|
||||||
input_dim=100,
|
input_dim=100,
|
||||||
latent_dim=2,
|
latent_dim=2,
|
||||||
proto_lr=0.0001,
|
proto_lr=0.001,
|
||||||
bb_lr=0.0001,
|
bb_lr=0.001,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.SiameseGMLVQ(
|
model = pt.models.SiameseGMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
# optimizer=torch.optim.SGD,
|
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(torch.nn.Module):
|
class Backbone(torch.nn.Module):
|
||||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||||
@ -41,7 +40,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[1, 2, 2],
|
distribution=[3, 4, 5],
|
||||||
proto_lr=0.001,
|
proto_lr=0.001,
|
||||||
bb_lr=0.001,
|
bb_lr=0.001,
|
||||||
)
|
)
|
||||||
@ -52,7 +51,10 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.LVQMLN(
|
model = pt.models.LVQMLN(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
|
prototypes_initializer=pt.initializers.SSCI(
|
||||||
|
train_ds,
|
||||||
|
transform=backbone,
|
||||||
|
),
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,11 +69,21 @@ if __name__ == "__main__":
|
|||||||
resolution=500,
|
resolution=500,
|
||||||
axis_off=True,
|
axis_off=True,
|
||||||
)
|
)
|
||||||
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
|
threshold=0.01,
|
||||||
|
idle_epochs=20,
|
||||||
|
prune_quota_per_epoch=2,
|
||||||
|
frequency=10,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
pruning,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import Lambda
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -28,19 +26,17 @@ if __name__ == "__main__":
|
|||||||
distribution=[2, 2, 3],
|
distribution=[2, 2, 3],
|
||||||
proto_lr=0.05,
|
proto_lr=0.05,
|
||||||
lambd=0.1,
|
lambd=0.1,
|
||||||
|
variance=1.0,
|
||||||
input_dim=2,
|
input_dim=2,
|
||||||
latent_dim=2,
|
latent_dim=2,
|
||||||
bb_lr=0.01,
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.probabilistic.PLVQ(
|
model = pt.models.RSLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
# prototype_initializer=pt.components.SMI(train_ds),
|
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
|
|
||||||
# prototype_initializer=pt.components.Zeros(2),
|
|
||||||
# prototype_initializer=pt.components.Ones(2, scale=2.0),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
@ -50,7 +46,7 @@ if __name__ == "__main__":
|
|||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
|
|
||||||
class Backbone(torch.nn.Module):
|
class Backbone(torch.nn.Module):
|
||||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||||
@ -52,7 +51,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.SiameseGLVQ(
|
model = pt.models.SiameseGLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
both_path_gradients=False,
|
both_path_gradients=False,
|
||||||
)
|
)
|
||||||
|
84
examples/warm_starting.py
Normal file
84
examples/warm_starting.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Prepare the data
|
||||||
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
# Initialize the gng
|
||||||
|
gng = pt.models.GrowingNeuralGas(
|
||||||
|
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
|
||||||
|
prototypes_initializer=pt.initializers.ZCI(2),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=20,
|
||||||
|
mode="min",
|
||||||
|
verbose=False,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer for GNG
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=200,
|
||||||
|
callbacks=[es],
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(gng, train_loader)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(
|
||||||
|
distribution=[],
|
||||||
|
lr=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warm-start prototypes
|
||||||
|
knn = pt.models.KNN(dict(k=1), data=train_ds)
|
||||||
|
prototypes = gng.prototypes
|
||||||
|
plabels = knn.predict(prototypes)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.GLVQ(
|
||||||
|
hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototypes_initializer=pt.initializers.LCI(prototypes),
|
||||||
|
labels_initializer=pt.initializers.LLI(plabels),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
|
|||||||
|
|
||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
from .glvq import (
|
||||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
GLVQ,
|
||||||
|
GLVQ1,
|
||||||
|
GLVQ21,
|
||||||
|
GMLVQ,
|
||||||
|
GRLVQ,
|
||||||
|
LGMLVQ,
|
||||||
|
LVQMLN,
|
||||||
|
ImageGLVQ,
|
||||||
|
ImageGMLVQ,
|
||||||
|
SiameseGLVQ,
|
||||||
|
SiameseGMLVQ,
|
||||||
|
)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||||
|
@ -5,9 +5,13 @@ from typing import Final, final
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import Components, LabeledComponents
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from ..core.competitions import WTAC
|
||||||
from prototorch.modules import WTAC, LambdaLayer
|
from ..core.components import Components, LabeledComponents
|
||||||
|
from ..core.distances import euclidean_distance
|
||||||
|
from ..core.initializers import LabelsInitializer
|
||||||
|
from ..core.pooling import stratified_min_pooling
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin(object):
|
||||||
@ -85,13 +89,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
if prototypes_initializer is not None:
|
||||||
if prototype_initializer is not None or initialized_prototypes is not None:
|
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
initializer=prototype_initializer,
|
initializer=prototypes_initializer,
|
||||||
initialized_components=initialized_prototypes,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
@ -109,23 +111,24 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
labels_initializer = kwargs.get("labels_initializer",
|
||||||
if prototype_initializer is not None or initialized_prototypes is not None:
|
LabelsInitializer())
|
||||||
|
if prototypes_initializer is not None:
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=self.hparams.distribution,
|
||||||
initializer=prototype_initializer,
|
components_initializer=prototypes_initializer,
|
||||||
initialized_components=initialized_prototypes,
|
labels_initializer=labels_initializer,
|
||||||
)
|
)
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.labels.detach().cpu()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
return len(self.proto_layer.distribution)
|
return self.proto_layer.num_classes
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
@ -134,15 +137,14 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self.compute_distances(x)
|
distances = self.compute_distances(x)
|
||||||
y_pred = self.predict_from_distances(distances)
|
plabels = self.proto_layer.labels
|
||||||
# TODO
|
winning = stratified_min_pooling(distances, plabels)
|
||||||
y_pred = torch.eye(self.num_classes, device=self.device)[
|
y_pred = torch.nn.functional.softmin(winning)
|
||||||
y_pred.long()] # depends on labels {0,...,num_classes}
|
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
def predict_from_distances(self, distances):
|
def predict_from_distances(self, distances):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
y_pred = self.competition_layer(distances, plabels)
|
y_pred = self.competition_layer(distances, plabels)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
@ -4,8 +4,9 @@ import logging
|
|||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from prototorch.components import Components
|
|
||||||
|
|
||||||
|
from ..core.components import Components
|
||||||
|
from ..core.initializers import LiteralCompInitializer
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
prune_quota_per_epoch=-1,
|
prune_quota_per_epoch=-1,
|
||||||
frequency=1,
|
frequency=1,
|
||||||
replace=False,
|
replace=False,
|
||||||
initializer=None,
|
prototypes_initializer=None,
|
||||||
verbose=False):
|
verbose=False):
|
||||||
self.threshold = threshold # minimum win ratio
|
self.threshold = threshold # minimum win ratio
|
||||||
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
||||||
@ -24,7 +25,7 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
self.frequency = frequency
|
self.frequency = frequency
|
||||||
self.replace = replace
|
self.replace = replace
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.initializer = initializer
|
self.prototypes_initializer = prototypes_initializer
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||||
@ -55,8 +56,9 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Re-adding pruned prototypes...")
|
print(f"Re-adding pruned prototypes...")
|
||||||
print(f"{distribution=}")
|
print(f"{distribution=}")
|
||||||
pl_module.add_prototypes(distribution=distribution,
|
pl_module.add_prototypes(
|
||||||
initializer=self.initializer)
|
distribution=distribution,
|
||||||
|
components_initializer=self.prototypes_initializer)
|
||||||
new_num_protos = pl_module.num_prototypes
|
new_num_protos = pl_module.num_prototypes
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"`num_prototypes` changed from {cur_num_protos} "
|
print(f"`num_prototypes` changed from {cur_num_protos} "
|
||||||
@ -116,7 +118,8 @@ class GNGCallback(pl.Callback):
|
|||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.add_components(
|
pl_module.proto_layer.add_components(
|
||||||
initialized_components=new_component.unsqueeze(0))
|
None,
|
||||||
|
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
|
||||||
|
|
||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
|
@ -1,49 +1,54 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
|
||||||
|
from ..core.competitions import CBCC
|
||||||
|
from ..core.components import ReasoningComponents
|
||||||
|
from ..core.initializers import RandomReasoningsInitializer
|
||||||
|
from ..core.losses import MarginLoss
|
||||||
|
from ..core.similarities import euclidean_similarity
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import ImagePrototypesMixin
|
from .abstract import ImagePrototypesMixin
|
||||||
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
|
|
||||||
euclidean_similarity, rescaled_cosine_similarity,
|
|
||||||
shift_activation)
|
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
class CBC(SiameseGLVQ):
|
class CBC(SiameseGLVQ):
|
||||||
"""Classification-By-Components."""
|
"""Classification-By-Components."""
|
||||||
def __init__(self, hparams, margin=0.1, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.margin = margin
|
|
||||||
self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
|
||||||
num_components = self.components.shape[0]
|
|
||||||
self.reasoning_layer = ReasoningLayer(num_components=num_components,
|
|
||||||
num_classes=self.num_classes)
|
|
||||||
self.component_layer = self.proto_layer
|
|
||||||
|
|
||||||
@property
|
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||||
def components(self):
|
components_initializer = kwargs.get("components_initializer", None)
|
||||||
return self.prototypes
|
reasonings_initializer = kwargs.get("reasonings_initializer",
|
||||||
|
RandomReasoningsInitializer())
|
||||||
|
self.components_layer = ReasoningComponents(
|
||||||
|
self.hparams.distribution,
|
||||||
|
components_initializer=components_initializer,
|
||||||
|
reasonings_initializer=reasonings_initializer,
|
||||||
|
)
|
||||||
|
self.similarity_layer = LambdaLayer(similarity_fn)
|
||||||
|
self.competition_layer = CBCC()
|
||||||
|
|
||||||
@property
|
# Namespace hook
|
||||||
def reasonings(self):
|
self.proto_layer = self.components_layer
|
||||||
return self.reasoning_layer.reasonings.cpu()
|
|
||||||
|
self.loss = MarginLoss(self.hparams.margin)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
components, _ = self.component_layer()
|
components, reasonings = self.components_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_components = self.backbone(components)
|
latent_components = self.backbone(components)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(True)
|
||||||
detections = self.similarity_fn(latent_x, latent_components)
|
detections = self.similarity_layer(latent_x, latent_components)
|
||||||
probs = self.reasoning_layer(detections)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
# x = x.view(x.size(0), -1)
|
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.reasoning_layer.num_classes
|
num_classes = self.num_classes
|
||||||
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
|
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
|
||||||
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
|
loss = self.loss(y_pred, y_true).mean(dim=0)
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
@ -70,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
|
|||||||
"""CBC model that constrains the components to the range [0, 1] by
|
"""CBC model that constrains the components to the range [0, 1] by
|
||||||
clamping after updates.
|
clamping after updates.
|
||||||
"""
|
"""
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
# Namespace hook
|
|
||||||
self.proto_layer = self.component_layer
|
|
||||||
|
@ -5,23 +5,32 @@ Modules not yet available in prototorch go here temporarily.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.distances import euclidean_distance
|
|
||||||
from prototorch.functions.similarities import cosine_similarity
|
from ..core.similarities import gaussian
|
||||||
|
|
||||||
|
|
||||||
def rescaled_cosine_similarity(x, y):
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
"""Cosine Similarity rescaled to [0, 1]."""
|
order = torch.argsort(distances, dim=1)
|
||||||
similarities = cosine_similarity(x, y)
|
ranks = torch.argsort(order, dim=1)
|
||||||
return (similarities + 1.0) / 2.0
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
def shift_activation(x):
|
class GaussianPrior(torch.nn.Module):
|
||||||
return (x + 1.0) / 2.0
|
def __init__(self, variance):
|
||||||
|
super().__init__()
|
||||||
|
self.variance = variance
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return gaussian(distances, self.variance)
|
||||||
|
|
||||||
|
|
||||||
def euclidean_similarity(x, y, variance=1.0):
|
class RankScaledGaussianPrior(torch.nn.Module):
|
||||||
d = euclidean_distance(x, y)
|
def __init__(self, lambd):
|
||||||
return torch.exp(-(d * d) / (2 * variance))
|
super().__init__()
|
||||||
|
self.lambd = lambd
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return rank_scaled_gaussian(distances, self.lambd)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTopology(torch.nn.Module):
|
class ConnectionTopology(torch.nn.Module):
|
||||||
@ -79,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(agelimit): ({self.agelimit})"
|
return f"(agelimit): ({self.agelimit})"
|
||||||
|
|
||||||
|
|
||||||
class CosineSimilarity(torch.nn.Module):
|
|
||||||
def __init__(self, activation=shift_activation):
|
|
||||||
super().__init__()
|
|
||||||
self.activation = activation
|
|
||||||
|
|
||||||
def forward(self, x, y):
|
|
||||||
epsilon = torch.finfo(x.dtype).eps
|
|
||||||
normed_x = (x / x.pow(2).sum(dim=tuple(range(
|
|
||||||
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
|
||||||
start_dim=1)
|
|
||||||
normed_y = (y / y.pow(2).sum(dim=tuple(range(
|
|
||||||
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
|
||||||
start_dim=1)
|
|
||||||
# normed_x = (x / torch.linalg.norm(x, dim=1))
|
|
||||||
diss = torch.inner(normed_x, normed_y)
|
|
||||||
return self.activation(diss)
|
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
|
||||||
def __init__(self,
|
|
||||||
margin=0.3,
|
|
||||||
size_average=None,
|
|
||||||
reduce=None,
|
|
||||||
reduction="mean"):
|
|
||||||
super().__init__(size_average, reduce, reduction)
|
|
||||||
self.margin = margin
|
|
||||||
|
|
||||||
def forward(self, input_, target):
|
|
||||||
dp = torch.sum(target * input_, dim=-1)
|
|
||||||
dm = torch.max(input_ - target, dim=-1).values
|
|
||||||
return torch.nn.functional.relu(dm - dp + self.margin)
|
|
||||||
|
|
||||||
|
|
||||||
class ReasoningLayer(torch.nn.Module):
|
|
||||||
def __init__(self, num_components, num_classes, num_replicas=1):
|
|
||||||
super().__init__()
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.num_classes = num_classes
|
|
||||||
probabilities_init = torch.zeros(2, 1, num_components,
|
|
||||||
self.num_classes)
|
|
||||||
probabilities_init.uniform_(0.4, 0.6)
|
|
||||||
# TODO Use `self.register_parameter("param", Paramater(param))` instead
|
|
||||||
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reasonings(self):
|
|
||||||
pk = self.reasoning_probabilities[0]
|
|
||||||
nk = (1 - pk) * self.reasoning_probabilities[1]
|
|
||||||
ik = 1 - pk - nk
|
|
||||||
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
|
||||||
return img.unsqueeze(1)
|
|
||||||
|
|
||||||
def forward(self, detections):
|
|
||||||
pk = self.reasoning_probabilities[0].clamp(0, 1)
|
|
||||||
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
|
|
||||||
numerator = (detections @ (pk - nk)) + nk.sum(1)
|
|
||||||
probs = numerator / (pk + nk).sum(1)
|
|
||||||
probs = probs.squeeze(0)
|
|
||||||
return probs
|
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.activations import get_activation
|
|
||||||
from prototorch.functions.competitions import wtac
|
|
||||||
from prototorch.functions.distances import (lomega_distance, omega_distance,
|
|
||||||
squared_euclidean_distance)
|
|
||||||
from prototorch.functions.helper import get_flat
|
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
|
||||||
from prototorch.components import LinearMapping
|
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||||
|
from ..core.initializers import EyeTransformInitializer
|
||||||
|
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
from ..nn.activations import get_activation
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
@ -30,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
# Loss
|
# Loss
|
||||||
self.loss = LossLayer(glvq_loss)
|
self.loss = LossLayer(glvq_loss)
|
||||||
|
|
||||||
# Prototype metrics
|
|
||||||
self.initialize_prototype_win_ratios()
|
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"prototype_win_ratios",
|
"prototype_win_ratios",
|
||||||
@ -59,7 +54,7 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
mu = self.loss(out, y, prototype_labels=plabels)
|
mu = self.loss(out, y, prototype_labels=plabels)
|
||||||
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
|
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
|
||||||
loss = batch_loss.sum(dim=0)
|
loss = batch_loss.sum(dim=0)
|
||||||
@ -135,7 +130,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = get_flat(x, protos)
|
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
@ -240,17 +235,13 @@ class GMLVQ(GLVQ):
|
|||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega_initializer = kwargs.get("omega_initializer", None)
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
initialized_omega = kwargs.get("initialized_omega", None)
|
EyeTransformInitializer())
|
||||||
if omega_initializer is not None or initialized_omega is not None:
|
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||||
self.omega_layer = LinearMapping(
|
self.hparams.latent_dim)
|
||||||
mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim),
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
initializer=omega_initializer,
|
self.backbone = LambdaLayer(lambda x: x @ self._omega,
|
||||||
initialized_linearmapping=initialized_omega,
|
name="omega matrix")
|
||||||
)
|
|
||||||
|
|
||||||
self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
|
|
||||||
self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
@ -264,24 +255,6 @@ class GMLVQ(GLVQ):
|
|||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
||||||
|
|
||||||
def predict_latent(self, x, map_protos=True):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space.
|
|
||||||
|
|
||||||
Only the prototypes are embedded in the latent space using the
|
|
||||||
backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
if map_protos:
|
|
||||||
protos = self.backbone(protos)
|
|
||||||
d = squared_euclidean_distance(x, protos)
|
|
||||||
y_pred = wtac(d, plabels)
|
|
||||||
return y_pred
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LGMLVQ(GMLVQ):
|
class LGMLVQ(GMLVQ):
|
||||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from prototorch.components import LabeledComponents
|
from ..core.competitions import KNNC
|
||||||
from prototorch.modules import KNNC
|
from ..core.components import LabeledComponents
|
||||||
|
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||||
|
from ..utils.utils import parse_data_arg
|
||||||
from .abstract import SupervisedPrototypeModel
|
from .abstract import SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
@ -19,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
data = kwargs.get("data", None)
|
data = kwargs.get("data", None)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise ValueError("KNN requires data, but was not provided!")
|
raise ValueError("KNN requires data, but was not provided!")
|
||||||
|
data, targets = parse_data_arg(data)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
self.proto_layer = LabeledComponents(initialized_components=data)
|
self.proto_layer = LabeledComponents(
|
||||||
|
distribution=[],
|
||||||
|
components_initializer=LiteralCompInitializer(data),
|
||||||
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""LVQ models that are optimized using non-gradient methods."""
|
"""LVQ models that are optimized using non-gradient methods."""
|
||||||
|
|
||||||
from prototorch.functions.losses import _get_dp_dm
|
from ..core.losses import _get_dp_dm
|
||||||
|
|
||||||
from .abstract import NonGradientMixin
|
from .abstract import NonGradientMixin
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
@ -10,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
"""Learning Vector Quantization 1."""
|
"""Learning Vector Quantization 1."""
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos = self.proto_layer.components
|
protos = self.proto_layer.components
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -29,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||||
strict=False)
|
strict=False)
|
||||||
|
|
||||||
|
print(f"{dis=}")
|
||||||
|
print(f"{y=}")
|
||||||
# Logging
|
# Logging
|
||||||
self.log_acc(dis, y, tag="train_acc")
|
self.log_acc(dis, y, tag="train_acc")
|
||||||
|
|
||||||
@ -39,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ):
|
|||||||
"""Learning Vector Quantization 2.1."""
|
"""Learning Vector Quantization 2.1."""
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos = self.proto_layer.components
|
protos = self.proto_layer.components
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
"""Probabilistic GLVQ methods"""
|
"""Probabilistic GLVQ methods"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
|
||||||
from prototorch.functions.pooling import (stratified_min_pooling,
|
|
||||||
stratified_sum_pooling)
|
|
||||||
from prototorch.functions.transforms import (GaussianPrior,
|
|
||||||
RankScaledGaussianPrior)
|
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
|
||||||
|
|
||||||
|
from ..core.losses import nllr_loss, rslvq_loss
|
||||||
|
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
|
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||||
from .glvq import GLVQ, SiameseGMLVQ
|
from .glvq import GLVQ, SiameseGMLVQ
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +20,7 @@ class CELVQ(GLVQ):
|
|||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
||||||
probs = -1.0 * winning
|
probs = -1.0 * winning
|
||||||
batch_loss = self.loss(probs, y.long())
|
batch_loss = self.loss(probs, y.long())
|
||||||
@ -56,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.forward(x)
|
out = self.forward(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
batch_loss = self.loss(out, y, plabels)
|
batch_loss = self.loss(out, y, plabels)
|
||||||
loss = batch_loss.sum(dim=0)
|
loss = batch_loss.sum(dim=0)
|
||||||
return loss
|
return loss
|
||||||
@ -89,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
|||||||
self.hparams.lambd)
|
self.hparams.lambd)
|
||||||
self.loss = torch.nn.KLDivLoss()
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
# FIXME
|
||||||
x, y = batch
|
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
out = self.forward(x)
|
# x, y = batch
|
||||||
y_dist = torch.nn.functional.one_hot(
|
# y_pred = self(x)
|
||||||
y.long(), num_classes=self.num_classes).float()
|
# batch_loss = self.loss(y_pred, y)
|
||||||
batch_loss = self.loss(out, y_dist)
|
# loss = batch_loss.sum(dim=0)
|
||||||
loss = batch_loss.sum(dim=0)
|
# return loss
|
||||||
return loss
|
|
||||||
|
@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.competitions import wtac
|
|
||||||
from prototorch.functions.distances import squared_euclidean_distance
|
|
||||||
from prototorch.modules import LambdaLayer
|
|
||||||
from prototorch.modules.losses import NeuralGasEnergy
|
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import squared_euclidean_distance
|
||||||
|
from ..core.losses import NeuralGasEnergy
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||||
from .callbacks import GNGCallback
|
from .callbacks import GNGCallback
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
@ -7,6 +7,8 @@ import torchvision
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from ..utils.utils import mesh2d
|
||||||
|
|
||||||
|
|
||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -73,23 +75,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
ax.axis("off")
|
ax.axis("off")
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
def get_mesh_input(self, x):
|
def plot_data(self, ax, x, y):
|
||||||
x_shift = self.border * np.ptp(x[:, 0])
|
|
||||||
y_shift = self.border * np.ptp(x[:, 1])
|
|
||||||
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
|
|
||||||
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
|
|
||||||
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
|
|
||||||
np.linspace(y_min, y_max, self.resolution))
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
||||||
return mesh_input, xx, yy
|
|
||||||
|
|
||||||
def perform_pca_2D(self, data):
|
|
||||||
(_, eigVal, eigVec) = torch.pca_lowrank(data, q=2)
|
|
||||||
return data @ eigVec
|
|
||||||
|
|
||||||
def plot_data(self, ax, x, y, pca=False):
|
|
||||||
if pca:
|
|
||||||
x = self.perform_pca_2D(x)
|
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
x[:, 0],
|
x[:, 0],
|
||||||
x[:, 1],
|
x[:, 1],
|
||||||
@ -100,9 +86,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
s=30,
|
s=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot_protos(self, ax, protos, plabels, pca=False):
|
def plot_protos(self, ax, protos, plabels):
|
||||||
if pca:
|
|
||||||
protos = self.perform_pca_2D(protos)
|
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
protos[:, 0],
|
protos[:, 0],
|
||||||
protos[:, 1],
|
protos[:, 1],
|
||||||
@ -146,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
@ -181,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
if self.show_protos:
|
if self.show_protos:
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
else:
|
else:
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict_latent(mesh_input,
|
y_pred = pl_module.predict_latent(mesh_input,
|
||||||
@ -194,50 +178,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
class VisGMLVQ2D(Vis2DAbstract):
|
|
||||||
def __init__(self, *args, map_protos=True, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.map_protos = map_protos
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
|
||||||
plabels = pl_module.prototype_labels
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
|
||||||
device = pl_module.device
|
|
||||||
with torch.no_grad():
|
|
||||||
x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
|
|
||||||
x_train = x_train.cpu().detach()
|
|
||||||
if self.map_protos:
|
|
||||||
with torch.no_grad():
|
|
||||||
protos = pl_module.backbone(torch.Tensor(protos).to(device))
|
|
||||||
protos = protos.cpu().detach()
|
|
||||||
ax = self.setup_ax()
|
|
||||||
if x_train.shape[1] > 2:
|
|
||||||
self.plot_data(ax, x_train, y_train, pca=True)
|
|
||||||
else:
|
|
||||||
self.plot_data(ax, x_train, y_train, pca=False)
|
|
||||||
if self.show_protos:
|
|
||||||
if protos.shape[1] > 2:
|
|
||||||
self.plot_protos(ax, protos, plabels, pca=True)
|
|
||||||
else:
|
|
||||||
self.plot_protos(ax, protos, plabels, pca=False)
|
|
||||||
### something to work on: meshgrid with pca
|
|
||||||
# x = np.vstack((x_train, protos))
|
|
||||||
# mesh_input, xx, yy = self.get_mesh_input(x)
|
|
||||||
#else:
|
|
||||||
# mesh_input, xx, yy = self.get_mesh_input(x_train)
|
|
||||||
#_components = pl_module.proto_layer._components
|
|
||||||
#mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
|
||||||
#y_pred = pl_module.predict_latent(mesh_input,
|
|
||||||
# map_protos=self.map_protos)
|
|
||||||
#y_pred = y_pred.cpu().reshape(xx.shape)
|
|
||||||
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if not self.precheck(trainer):
|
if not self.precheck(trainer):
|
||||||
@ -250,8 +190,8 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
_components = pl_module.component_layer._components
|
_components = pl_module.components_layer._components
|
||||||
y_pred = pl_module.predict(
|
y_pred = pl_module.predict(
|
||||||
torch.Tensor(mesh_input).type_as(_components))
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user