refactor(api)!: merge the new api changes into dev
This commit is contained in:
		@@ -4,7 +4,10 @@ commit = True
 | 
				
			|||||||
tag = True
 | 
					tag = True
 | 
				
			||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
 | 
					parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
 | 
				
			||||||
serialize = {major}.{minor}.{patch}
 | 
					serialize = {major}.{minor}.{patch}
 | 
				
			||||||
 | 
					message = bump: {current_version} → {new_version}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:setup.py]
 | 
					[bumpversion:file:setup.py]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:./prototorch/models/__init__.py]
 | 
					[bumpversion:file:./prototorch/models/__init__.py]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[bumpversion:file:./docs/source/conf.py]
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										17
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										17
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -128,14 +128,19 @@ dmypy.json
 | 
				
			|||||||
# Pyre type checker
 | 
					# Pyre type checker
 | 
				
			||||||
.pyre/
 | 
					.pyre/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Datasets
 | 
					 | 
				
			||||||
datasets/
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# PyTorch-Lightning
 | 
					 | 
				
			||||||
lightning_logs/
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
.vscode/
 | 
					.vscode/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Vim
 | 
				
			||||||
 | 
					*~
 | 
				
			||||||
 | 
					*.swp
 | 
				
			||||||
 | 
					*.swo
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#  Pytorch Models or Weights
 | 
					#  Pytorch Models or Weights
 | 
				
			||||||
#  If necessary make exceptions for single pretrained models
 | 
					#  If necessary make exceptions for single pretrained models
 | 
				
			||||||
*.pt
 | 
					*.pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Artifacts created by ProtoTorch Models
 | 
				
			||||||
 | 
					datasets/
 | 
				
			||||||
 | 
					lightning_logs/
 | 
				
			||||||
 | 
					examples/_*.py
 | 
				
			||||||
 | 
					examples/_*.ipynb
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,6 @@
 | 
				
			|||||||
# See https://pre-commit.com for more information
 | 
					# See https://pre-commit.com for more information
 | 
				
			||||||
# See https://pre-commit.com/hooks.html for more hooks
 | 
					# See https://pre-commit.com/hooks.html for more hooks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
repos:
 | 
					repos:
 | 
				
			||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
 | 
					- repo: https://github.com/pre-commit/pre-commit-hooks
 | 
				
			||||||
  rev: v4.0.1
 | 
					  rev: v4.0.1
 | 
				
			||||||
@@ -11,7 +12,6 @@ repos:
 | 
				
			|||||||
  - id: check-ast
 | 
					  - id: check-ast
 | 
				
			||||||
  - id: check-case-conflict
 | 
					  - id: check-case-conflict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
- repo: https://github.com/myint/autoflake
 | 
					- repo: https://github.com/myint/autoflake
 | 
				
			||||||
  rev: v1.4
 | 
					  rev: v1.4
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
@@ -23,32 +23,31 @@ repos:
 | 
				
			|||||||
  - id: isort
 | 
					  - id: isort
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/pre-commit/mirrors-mypy
 | 
					- repo: https://github.com/pre-commit/mirrors-mypy
 | 
				
			||||||
    rev: 'v0.902'
 | 
					  rev: v0.902
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: mypy
 | 
					  - id: mypy
 | 
				
			||||||
    files: prototorch
 | 
					    files: prototorch
 | 
				
			||||||
    additional_dependencies: [types-pkg_resources]
 | 
					    additional_dependencies: [types-pkg_resources]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/pre-commit/mirrors-yapf
 | 
					- repo: https://github.com/pre-commit/mirrors-yapf
 | 
				
			||||||
    rev: 'v0.31.0'  # Use the sha / tag you want to point at
 | 
					  rev: v0.31.0
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: yapf
 | 
					  - id: yapf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/pre-commit/pygrep-hooks
 | 
					- repo: https://github.com/pre-commit/pygrep-hooks
 | 
				
			||||||
    rev: v1.9.0  # Use the ref you want to point at
 | 
					  rev: v1.9.0
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: python-use-type-annotations
 | 
					  - id: python-use-type-annotations
 | 
				
			||||||
  - id: python-no-log-warn
 | 
					  - id: python-no-log-warn
 | 
				
			||||||
  - id: python-check-blanket-noqa
 | 
					  - id: python-check-blanket-noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
- repo: https://github.com/asottile/pyupgrade
 | 
					- repo: https://github.com/asottile/pyupgrade
 | 
				
			||||||
  rev: v2.19.4
 | 
					  rev: v2.19.4
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: pyupgrade
 | 
					  - id: pyupgrade
 | 
				
			||||||
 | 
					
 | 
				
			||||||
-   repo: https://github.com/jorisroovers/gitlint
 | 
					- repo: https://github.com/si-cim/gitlint
 | 
				
			||||||
    rev: "v0.15.1"
 | 
					  rev: v0.15.2-unofficial
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: gitlint
 | 
					  - id: gitlint
 | 
				
			||||||
    args: [--contrib=CT1, --ignore=B6, --msg-filename]
 | 
					    args: [--contrib=CT1, --ignore=B6, --msg-filename]
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										34
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								README.md
									
									
									
									
									
								
							@@ -20,23 +20,6 @@ pip install prototorch_models
 | 
				
			|||||||
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 | 
					of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 | 
				
			||||||
be available for use in your Python environment as `prototorch.models`.
 | 
					be available for use in your Python environment as `prototorch.models`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Contribution
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
This repository contains definition for [git hooks](https://githooks.com).
 | 
					 | 
				
			||||||
[Pre-commit](https://pre-commit.com) is automatically installed as development
 | 
					 | 
				
			||||||
dependency with prototorch or you can install it manually with `pip install
 | 
					 | 
				
			||||||
pre-commit`.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Please install the hooks by running:
 | 
					 | 
				
			||||||
```bash
 | 
					 | 
				
			||||||
pre-commit install
 | 
					 | 
				
			||||||
pre-commit install --hook-type commit-msg
 | 
					 | 
				
			||||||
```
 | 
					 | 
				
			||||||
before creating the first commit.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
The commit will fail if the commit message does not follow the specification
 | 
					 | 
				
			||||||
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## Available models
 | 
					## Available models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### LVQ Family
 | 
					### LVQ Family
 | 
				
			||||||
@@ -103,6 +86,23 @@ To assist in the development process, you may also find it useful to install
 | 
				
			|||||||
please avoid installing Tensorflow in this environment. It is known to cause
 | 
					please avoid installing Tensorflow in this environment. It is known to cause
 | 
				
			||||||
problems with PyTorch-Lightning.**
 | 
					problems with PyTorch-Lightning.**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Contribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This repository contains definition for [git hooks](https://githooks.com).
 | 
				
			||||||
 | 
					[Pre-commit](https://pre-commit.com) is automatically installed as development
 | 
				
			||||||
 | 
					dependency with prototorch or you can install it manually with `pip install
 | 
				
			||||||
 | 
					pre-commit`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Please install the hooks by running:
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					pre-commit install
 | 
				
			||||||
 | 
					pre-commit install --hook-type commit-msg
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					before creating the first commit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The commit will fail if the commit message does not follow the specification
 | 
				
			||||||
 | 
					provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## FAQ
 | 
					## FAQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### How do I update the plugin?
 | 
					### How do I update the plugin?
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# The full version, including alpha/beta/rc tags
 | 
					# The full version, including alpha/beta/rc tags
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
release = "0.4.4"
 | 
					release = "0.1.8"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# -- General configuration ---------------------------------------------------
 | 
					# -- General configuration ---------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -24,14 +23,18 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=[2, 2, 2],
 | 
					        distribution=[1, 0, 3],
 | 
				
			||||||
        proto_lr=0.1,
 | 
					        margin=0.1,
 | 
				
			||||||
 | 
					        proto_lr=0.01,
 | 
				
			||||||
 | 
					        bb_lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.CBC(
 | 
					    model = pt.models.CBC(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
 | 
					        components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
 | 
				
			||||||
 | 
					        reasonings_iniitializer=pt.initializers.
 | 
				
			||||||
 | 
					        PurePositiveReasoningsInitializer(),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -37,7 +36,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.CELVQ(
 | 
					    model = pt.models.CELVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.Ones(2, scale=3),
 | 
					        prototypes_initializer=pt.initializers.FVCI(2, 3.0),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,12 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -24,7 +23,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution={
 | 
					        distribution={
 | 
				
			||||||
            "num_classes": 3,
 | 
					            "num_classes": 3,
 | 
				
			||||||
            "prototypes_per_class": 4
 | 
					            "per_class": 4
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        lr=0.01,
 | 
					        lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@@ -33,7 +32,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    model = pt.models.GLVQ(
 | 
					    model = pt.models.GLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
        lr_scheduler=ExponentialLR,
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -26,7 +25,6 @@ if __name__ == "__main__":
 | 
				
			|||||||
        distribution=(num_classes, prototypes_per_class),
 | 
					        distribution=(num_classes, prototypes_per_class),
 | 
				
			||||||
        transfer_function="swish_beta",
 | 
					        transfer_function="swish_beta",
 | 
				
			||||||
        transfer_beta=10.0,
 | 
					        transfer_beta=10.0,
 | 
				
			||||||
        # lr=0.1,
 | 
					 | 
				
			||||||
        proto_lr=0.1,
 | 
					        proto_lr=0.1,
 | 
				
			||||||
        bb_lr=0.1,
 | 
					        bb_lr=0.1,
 | 
				
			||||||
        input_dim=2,
 | 
					        input_dim=2,
 | 
				
			||||||
@@ -37,7 +35,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    model = pt.models.GMLVQ(
 | 
					    model = pt.models.GMLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
 | 
					        prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
@@ -47,12 +45,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
        block=False,
 | 
					        block=False,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    pruning = pt.models.PruneLoserPrototypes(
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
        threshold=0.02,
 | 
					        threshold=0.01,
 | 
				
			||||||
        idle_epochs=10,
 | 
					        idle_epochs=10,
 | 
				
			||||||
        prune_quota_per_epoch=5,
 | 
					        prune_quota_per_epoch=5,
 | 
				
			||||||
        frequency=2,
 | 
					        frequency=5,
 | 
				
			||||||
        replace=True,
 | 
					        replace=True,
 | 
				
			||||||
        initializer=pt.components.SSI(train_ds, noise=1e-2),
 | 
					        prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
 | 
				
			||||||
        verbose=True,
 | 
					        verbose=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    es = pl.callbacks.EarlyStopping(
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
@@ -68,7 +66,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            # es,
 | 
					            # es, # FIXME
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
        terminate_on_nan=True,
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,59 +0,0 @@
 | 
				
			|||||||
"""GLVQ example using the Iris dataset."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import argparse
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    # Command-line arguments
 | 
					 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					 | 
				
			||||||
    args = parser.parse_args()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Dataset
 | 
					 | 
				
			||||||
    train_ds = pt.datasets.Iris()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Dataloaders
 | 
					 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Hyperparameters
 | 
					 | 
				
			||||||
    hparams = dict(
 | 
					 | 
				
			||||||
        input_dim=4,
 | 
					 | 
				
			||||||
        latent_dim=3,
 | 
					 | 
				
			||||||
        distribution={
 | 
					 | 
				
			||||||
            "num_classes": 3,
 | 
					 | 
				
			||||||
            "prototypes_per_class": 2
 | 
					 | 
				
			||||||
        },
 | 
					 | 
				
			||||||
        proto_lr=0.0005,
 | 
					 | 
				
			||||||
        bb_lr=0.0005,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Initialize the model
 | 
					 | 
				
			||||||
    model = pt.models.GMLVQ(
 | 
					 | 
				
			||||||
        hparams,
 | 
					 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds),
 | 
					 | 
				
			||||||
        lr_scheduler=ExponentialLR,
 | 
					 | 
				
			||||||
        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
					 | 
				
			||||||
        omega_initializer=pt.components.PCA(train_ds.data)
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					 | 
				
			||||||
    #model.example_input_array = torch.zeros(4, 2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Callbacks
 | 
					 | 
				
			||||||
    vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setup trainer
 | 
					 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					 | 
				
			||||||
        args,
 | 
					 | 
				
			||||||
        callbacks=[vis],
 | 
					 | 
				
			||||||
        weights_summary="full",
 | 
					 | 
				
			||||||
        accelerator="ddp",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Training loop
 | 
					 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					 | 
				
			||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -30,7 +29,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.GrowingNeuralGas(
 | 
					    model = pt.models.GrowingNeuralGas(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.Zeros(2),
 | 
					        prototypes_initializer=pt.initializers.ZCI(2),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,25 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.utils.colors import hex_to_rgb
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def hex_to_rgb(hex_values):
 | 
					 | 
				
			||||||
    for v in hex_values:
 | 
					 | 
				
			||||||
        v = v.lstrip('#')
 | 
					 | 
				
			||||||
        lv = len(v)
 | 
					 | 
				
			||||||
        c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
 | 
					 | 
				
			||||||
        yield c
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def rgb_to_hex(rgb_values):
 | 
					 | 
				
			||||||
    for v in rgb_values:
 | 
					 | 
				
			||||||
        c = "%02x%02x%02x" % tuple(v)
 | 
					 | 
				
			||||||
        yield c
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Vis2DColorSOM(pl.Callback):
 | 
					class Vis2DColorSOM(pl.Callback):
 | 
				
			||||||
@@ -93,7 +79,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.KohonenSOM(
 | 
					    model = pt.models.KohonenSOM(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.Random(3),
 | 
					        prototypes_initializer=pt.initializers.RNCI(3),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,23 +2,22 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					 | 
				
			||||||
    train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
    pl.utilities.seed.seed_everything(seed=2)
 | 
					    pl.utilities.seed.seed_everything(seed=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataloaders
 | 
					    # Dataloaders
 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
				
			||||||
                                               batch_size=256,
 | 
					                                               batch_size=256,
 | 
				
			||||||
@@ -32,8 +31,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.LGMLVQ(hparams,
 | 
					    model = pt.models.LGMLVQ(
 | 
				
			||||||
                             prototype_initializer=pt.components.SMI(train_ds))
 | 
					        hparams,
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
    model.example_input_array = torch.zeros(4, 2)
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,11 +3,10 @@
 | 
				
			|||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import matplotlib.pyplot as plt
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def plot_matrix(matrix):
 | 
					def plot_matrix(matrix):
 | 
				
			||||||
    title = "Lambda matrix"
 | 
					    title = "Lambda matrix"
 | 
				
			||||||
@@ -40,20 +39,19 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution={
 | 
					        distribution={
 | 
				
			||||||
            "num_classes": 2,
 | 
					            "num_classes": 2,
 | 
				
			||||||
            "prototypes_per_class": 1
 | 
					            "per_class": 1,
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        input_dim=100,
 | 
					        input_dim=100,
 | 
				
			||||||
        latent_dim=2,
 | 
					        latent_dim=2,
 | 
				
			||||||
        proto_lr=0.0001,
 | 
					        proto_lr=0.001,
 | 
				
			||||||
        bb_lr=0.0001,
 | 
					        bb_lr=0.001,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.SiameseGMLVQ(
 | 
					    model = pt.models.SiameseGMLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        # optimizer=torch.optim.SGD,
 | 
					 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Summary
 | 
					    # Summary
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Backbone(torch.nn.Module):
 | 
					class Backbone(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
					    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
				
			||||||
@@ -41,7 +40,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=[1, 2, 2],
 | 
					        distribution=[3, 4, 5],
 | 
				
			||||||
        proto_lr=0.001,
 | 
					        proto_lr=0.001,
 | 
				
			||||||
        bb_lr=0.001,
 | 
					        bb_lr=0.001,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@@ -52,7 +51,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.LVQMLN(
 | 
					    model = pt.models.LVQMLN(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
 | 
					        prototypes_initializer=pt.initializers.SSCI(
 | 
				
			||||||
 | 
					            train_ds,
 | 
				
			||||||
 | 
					            transform=backbone,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
        backbone=backbone,
 | 
					        backbone=backbone,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -67,11 +69,21 @@ if __name__ == "__main__":
 | 
				
			|||||||
        resolution=500,
 | 
					        resolution=500,
 | 
				
			||||||
        axis_off=True,
 | 
					        axis_off=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
 | 
					        threshold=0.01,
 | 
				
			||||||
 | 
					        idle_epochs=20,
 | 
				
			||||||
 | 
					        prune_quota_per_epoch=2,
 | 
				
			||||||
 | 
					        frequency=10,
 | 
				
			||||||
 | 
					        verbose=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            pruning,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,9 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torchvision.transforms import Lambda
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
@@ -28,19 +26,17 @@ if __name__ == "__main__":
 | 
				
			|||||||
        distribution=[2, 2, 3],
 | 
					        distribution=[2, 2, 3],
 | 
				
			||||||
        proto_lr=0.05,
 | 
					        proto_lr=0.05,
 | 
				
			||||||
        lambd=0.1,
 | 
					        lambd=0.1,
 | 
				
			||||||
 | 
					        variance=1.0,
 | 
				
			||||||
        input_dim=2,
 | 
					        input_dim=2,
 | 
				
			||||||
        latent_dim=2,
 | 
					        latent_dim=2,
 | 
				
			||||||
        bb_lr=0.01,
 | 
					        bb_lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.probabilistic.PLVQ(
 | 
					    model = pt.models.RSLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
        # prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
 | 
					 | 
				
			||||||
        # prototype_initializer=pt.components.Zeros(2),
 | 
					 | 
				
			||||||
        # prototype_initializer=pt.components.Ones(2, scale=2.0),
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
@@ -50,7 +46,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
 | 
					    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Backbone(torch.nn.Module):
 | 
					class Backbone(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
					    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
				
			||||||
@@ -52,7 +51,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.SiameseGLVQ(
 | 
					    model = pt.models.SiameseGLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
        backbone=backbone,
 | 
					        backbone=backbone,
 | 
				
			||||||
        both_path_gradients=False,
 | 
					        both_path_gradients=False,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										84
									
								
								examples/warm_starting.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								examples/warm_starting.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
				
			|||||||
 | 
					"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Prepare the data
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the gng
 | 
				
			||||||
 | 
					    gng = pt.models.GrowingNeuralGas(
 | 
				
			||||||
 | 
					        hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.ZCI(2),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="loss",
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=20,
 | 
				
			||||||
 | 
					        mode="min",
 | 
				
			||||||
 | 
					        verbose=False,
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer for GNG
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        max_epochs=200,
 | 
				
			||||||
 | 
					        callbacks=[es],
 | 
				
			||||||
 | 
					        weights_summary=None,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(gng, train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        distribution=[],
 | 
				
			||||||
 | 
					        lr=0.01,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Warm-start prototypes
 | 
				
			||||||
 | 
					    knn = pt.models.KNN(dict(k=1), data=train_ds)
 | 
				
			||||||
 | 
					    prototypes = gng.prototypes
 | 
				
			||||||
 | 
					    plabels = knn.predict(prototypes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.GLVQ(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.LCI(prototypes),
 | 
				
			||||||
 | 
					        labels_initializer=pt.initializers.LLI(plabels),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
@@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 | 
					from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 | 
				
			||||||
from .cbc import CBC, ImageCBC
 | 
					from .cbc import CBC, ImageCBC
 | 
				
			||||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
 | 
					from .glvq import (
 | 
				
			||||||
                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
 | 
					    GLVQ,
 | 
				
			||||||
 | 
					    GLVQ1,
 | 
				
			||||||
 | 
					    GLVQ21,
 | 
				
			||||||
 | 
					    GMLVQ,
 | 
				
			||||||
 | 
					    GRLVQ,
 | 
				
			||||||
 | 
					    LGMLVQ,
 | 
				
			||||||
 | 
					    LVQMLN,
 | 
				
			||||||
 | 
					    ImageGLVQ,
 | 
				
			||||||
 | 
					    ImageGMLVQ,
 | 
				
			||||||
 | 
					    SiameseGLVQ,
 | 
				
			||||||
 | 
					    SiameseGMLVQ,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from .knn import KNN
 | 
					from .knn import KNN
 | 
				
			||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
					from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
				
			||||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 | 
					from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,9 +5,13 @@ from typing import Final, final
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.components import Components, LabeledComponents
 | 
					
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					from ..core.competitions import WTAC
 | 
				
			||||||
from prototorch.modules import WTAC, LambdaLayer
 | 
					from ..core.components import Components, LabeledComponents
 | 
				
			||||||
 | 
					from ..core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from ..core.initializers import LabelsInitializer
 | 
				
			||||||
 | 
					from ..core.pooling import stratified_min_pooling
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProtoTorchMixin(object):
 | 
					class ProtoTorchMixin(object):
 | 
				
			||||||
@@ -85,13 +89,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        prototype_initializer = kwargs.get("prototype_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
        initialized_prototypes = kwargs.get("initialized_prototypes", None)
 | 
					        if prototypes_initializer is not None:
 | 
				
			||||||
        if prototype_initializer is not None or initialized_prototypes is not None:
 | 
					 | 
				
			||||||
            self.proto_layer = Components(
 | 
					            self.proto_layer = Components(
 | 
				
			||||||
                self.hparams.num_prototypes,
 | 
					                self.hparams.num_prototypes,
 | 
				
			||||||
                initializer=prototype_initializer,
 | 
					                initializer=prototypes_initializer,
 | 
				
			||||||
                initialized_components=initialized_prototypes,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
@@ -109,23 +111,24 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        prototype_initializer = kwargs.get("prototype_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
        initialized_prototypes = kwargs.get("initialized_prototypes", None)
 | 
					        labels_initializer = kwargs.get("labels_initializer",
 | 
				
			||||||
        if prototype_initializer is not None or initialized_prototypes is not None:
 | 
					                                        LabelsInitializer())
 | 
				
			||||||
 | 
					        if prototypes_initializer is not None:
 | 
				
			||||||
            self.proto_layer = LabeledComponents(
 | 
					            self.proto_layer = LabeledComponents(
 | 
				
			||||||
                distribution=self.hparams.distribution,
 | 
					                distribution=self.hparams.distribution,
 | 
				
			||||||
                initializer=prototype_initializer,
 | 
					                components_initializer=prototypes_initializer,
 | 
				
			||||||
                initialized_components=initialized_prototypes,
 | 
					                labels_initializer=labels_initializer,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        self.competition_layer = WTAC()
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def prototype_labels(self):
 | 
					    def prototype_labels(self):
 | 
				
			||||||
        return self.proto_layer.component_labels.detach().cpu()
 | 
					        return self.proto_layer.labels.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def num_classes(self):
 | 
					    def num_classes(self):
 | 
				
			||||||
        return len(self.proto_layer.distribution)
 | 
					        return self.proto_layer.num_classes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
@@ -134,15 +137,14 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        distances = self.compute_distances(x)
 | 
					        distances = self.compute_distances(x)
 | 
				
			||||||
        y_pred = self.predict_from_distances(distances)
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
        # TODO
 | 
					        winning = stratified_min_pooling(distances, plabels)
 | 
				
			||||||
        y_pred = torch.eye(self.num_classes, device=self.device)[
 | 
					        y_pred = torch.nn.functional.softmin(winning)
 | 
				
			||||||
            y_pred.long()]  # depends on labels {0,...,num_classes}
 | 
					 | 
				
			||||||
        return y_pred
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict_from_distances(self, distances):
 | 
					    def predict_from_distances(self, distances):
 | 
				
			||||||
        with torch.no_grad():
 | 
					        with torch.no_grad():
 | 
				
			||||||
            plabels = self.proto_layer.component_labels
 | 
					            plabels = self.proto_layer.labels
 | 
				
			||||||
            y_pred = self.competition_layer(distances, plabels)
 | 
					            y_pred = self.competition_layer(distances, plabels)
 | 
				
			||||||
        return y_pred
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,8 +4,9 @@ import logging
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.components import Components
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.components import Components
 | 
				
			||||||
 | 
					from ..core.initializers import LiteralCompInitializer
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -16,7 +17,7 @@ class PruneLoserPrototypes(pl.Callback):
 | 
				
			|||||||
                 prune_quota_per_epoch=-1,
 | 
					                 prune_quota_per_epoch=-1,
 | 
				
			||||||
                 frequency=1,
 | 
					                 frequency=1,
 | 
				
			||||||
                 replace=False,
 | 
					                 replace=False,
 | 
				
			||||||
                 initializer=None,
 | 
					                 prototypes_initializer=None,
 | 
				
			||||||
                 verbose=False):
 | 
					                 verbose=False):
 | 
				
			||||||
        self.threshold = threshold  # minimum win ratio
 | 
					        self.threshold = threshold  # minimum win ratio
 | 
				
			||||||
        self.idle_epochs = idle_epochs  # epochs to wait before pruning
 | 
					        self.idle_epochs = idle_epochs  # epochs to wait before pruning
 | 
				
			||||||
@@ -24,7 +25,7 @@ class PruneLoserPrototypes(pl.Callback):
 | 
				
			|||||||
        self.frequency = frequency
 | 
					        self.frequency = frequency
 | 
				
			||||||
        self.replace = replace
 | 
					        self.replace = replace
 | 
				
			||||||
        self.verbose = verbose
 | 
					        self.verbose = verbose
 | 
				
			||||||
        self.initializer = initializer
 | 
					        self.prototypes_initializer = prototypes_initializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if (trainer.current_epoch + 1) < self.idle_epochs:
 | 
					        if (trainer.current_epoch + 1) < self.idle_epochs:
 | 
				
			||||||
@@ -55,8 +56,9 @@ class PruneLoserPrototypes(pl.Callback):
 | 
				
			|||||||
                if self.verbose:
 | 
					                if self.verbose:
 | 
				
			||||||
                    print(f"Re-adding pruned prototypes...")
 | 
					                    print(f"Re-adding pruned prototypes...")
 | 
				
			||||||
                    print(f"{distribution=}")
 | 
					                    print(f"{distribution=}")
 | 
				
			||||||
                pl_module.add_prototypes(distribution=distribution,
 | 
					                pl_module.add_prototypes(
 | 
				
			||||||
                                         initializer=self.initializer)
 | 
					                    distribution=distribution,
 | 
				
			||||||
 | 
					                    components_initializer=self.prototypes_initializer)
 | 
				
			||||||
            new_num_protos = pl_module.num_prototypes
 | 
					            new_num_protos = pl_module.num_prototypes
 | 
				
			||||||
            if self.verbose:
 | 
					            if self.verbose:
 | 
				
			||||||
                print(f"`num_prototypes` changed from {cur_num_protos} "
 | 
					                print(f"`num_prototypes` changed from {cur_num_protos} "
 | 
				
			||||||
@@ -116,7 +118,8 @@ class GNGCallback(pl.Callback):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            # Add component
 | 
					            # Add component
 | 
				
			||||||
            pl_module.proto_layer.add_components(
 | 
					            pl_module.proto_layer.add_components(
 | 
				
			||||||
                initialized_components=new_component.unsqueeze(0))
 | 
					                None,
 | 
				
			||||||
 | 
					                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Adjust Topology
 | 
					            # Adjust Topology
 | 
				
			||||||
            topology.add_prototype()
 | 
					            topology.add_prototype()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,49 +1,54 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import CBCC
 | 
				
			||||||
 | 
					from ..core.components import ReasoningComponents
 | 
				
			||||||
 | 
					from ..core.initializers import RandomReasoningsInitializer
 | 
				
			||||||
 | 
					from ..core.losses import MarginLoss
 | 
				
			||||||
 | 
					from ..core.similarities import euclidean_similarity
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
from .abstract import ImagePrototypesMixin
 | 
					from .abstract import ImagePrototypesMixin
 | 
				
			||||||
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
 | 
					 | 
				
			||||||
                     euclidean_similarity, rescaled_cosine_similarity,
 | 
					 | 
				
			||||||
                     shift_activation)
 | 
					 | 
				
			||||||
from .glvq import SiameseGLVQ
 | 
					from .glvq import SiameseGLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CBC(SiameseGLVQ):
 | 
					class CBC(SiameseGLVQ):
 | 
				
			||||||
    """Classification-By-Components."""
 | 
					    """Classification-By-Components."""
 | 
				
			||||||
    def __init__(self, hparams, margin=0.1, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
        self.margin = margin
 | 
					 | 
				
			||||||
        self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
					 | 
				
			||||||
        num_components = self.components.shape[0]
 | 
					 | 
				
			||||||
        self.reasoning_layer = ReasoningLayer(num_components=num_components,
 | 
					 | 
				
			||||||
                                              num_classes=self.num_classes)
 | 
					 | 
				
			||||||
        self.component_layer = self.proto_layer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					        similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
				
			||||||
    def components(self):
 | 
					        components_initializer = kwargs.get("components_initializer", None)
 | 
				
			||||||
        return self.prototypes
 | 
					        reasonings_initializer = kwargs.get("reasonings_initializer",
 | 
				
			||||||
 | 
					                                            RandomReasoningsInitializer())
 | 
				
			||||||
 | 
					        self.components_layer = ReasoningComponents(
 | 
				
			||||||
 | 
					            self.hparams.distribution,
 | 
				
			||||||
 | 
					            components_initializer=components_initializer,
 | 
				
			||||||
 | 
					            reasonings_initializer=reasonings_initializer,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.similarity_layer = LambdaLayer(similarity_fn)
 | 
				
			||||||
 | 
					        self.competition_layer = CBCC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					        # Namespace hook
 | 
				
			||||||
    def reasonings(self):
 | 
					        self.proto_layer = self.components_layer
 | 
				
			||||||
        return self.reasoning_layer.reasonings.cpu()
 | 
					
 | 
				
			||||||
 | 
					        self.loss = MarginLoss(self.hparams.margin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        components, _ = self.component_layer()
 | 
					        components, reasonings = self.components_layer()
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					        self.backbone.requires_grad_(self.both_path_gradients)
 | 
				
			||||||
        latent_components = self.backbone(components)
 | 
					        latent_components = self.backbone(components)
 | 
				
			||||||
        self.backbone.requires_grad_(True)
 | 
					        self.backbone.requires_grad_(True)
 | 
				
			||||||
        detections = self.similarity_fn(latent_x, latent_components)
 | 
					        detections = self.similarity_layer(latent_x, latent_components)
 | 
				
			||||||
        probs = self.reasoning_layer(detections)
 | 
					        probs = self.competition_layer(detections, reasonings)
 | 
				
			||||||
        return probs
 | 
					        return probs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = batch
 | 
					        x, y = batch
 | 
				
			||||||
        # x = x.view(x.size(0), -1)
 | 
					 | 
				
			||||||
        y_pred = self(x)
 | 
					        y_pred = self(x)
 | 
				
			||||||
        num_classes = self.reasoning_layer.num_classes
 | 
					        num_classes = self.num_classes
 | 
				
			||||||
        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
 | 
					        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
 | 
				
			||||||
        loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
 | 
					        loss = self.loss(y_pred, y_true).mean(dim=0)
 | 
				
			||||||
        return y_pred, loss
 | 
					        return y_pred, loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
@@ -70,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
 | 
				
			|||||||
    """CBC model that constrains the components to the range [0, 1] by
 | 
					    """CBC model that constrains the components to the range [0, 1] by
 | 
				
			||||||
    clamping after updates.
 | 
					    clamping after updates.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					 | 
				
			||||||
        # Namespace hook
 | 
					 | 
				
			||||||
        self.proto_layer = self.component_layer
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,23 +5,32 @@ Modules not yet available in prototorch go here temporarily.
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					
 | 
				
			||||||
from prototorch.functions.similarities import cosine_similarity
 | 
					from ..core.similarities import gaussian
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rescaled_cosine_similarity(x, y):
 | 
					def rank_scaled_gaussian(distances, lambd):
 | 
				
			||||||
    """Cosine Similarity rescaled to [0, 1]."""
 | 
					    order = torch.argsort(distances, dim=1)
 | 
				
			||||||
    similarities = cosine_similarity(x, y)
 | 
					    ranks = torch.argsort(order, dim=1)
 | 
				
			||||||
    return (similarities + 1.0) / 2.0
 | 
					    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def shift_activation(x):
 | 
					class GaussianPrior(torch.nn.Module):
 | 
				
			||||||
    return (x + 1.0) / 2.0
 | 
					    def __init__(self, variance):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.variance = variance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, distances):
 | 
				
			||||||
 | 
					        return gaussian(distances, self.variance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def euclidean_similarity(x, y, variance=1.0):
 | 
					class RankScaledGaussianPrior(torch.nn.Module):
 | 
				
			||||||
    d = euclidean_distance(x, y)
 | 
					    def __init__(self, lambd):
 | 
				
			||||||
    return torch.exp(-(d * d) / (2 * variance))
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.lambd = lambd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, distances):
 | 
				
			||||||
 | 
					        return rank_scaled_gaussian(distances, self.lambd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ConnectionTopology(torch.nn.Module):
 | 
					class ConnectionTopology(torch.nn.Module):
 | 
				
			||||||
@@ -79,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f"(agelimit): ({self.agelimit})"
 | 
					        return f"(agelimit): ({self.agelimit})"
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class CosineSimilarity(torch.nn.Module):
 | 
					 | 
				
			||||||
    def __init__(self, activation=shift_activation):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.activation = activation
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, x, y):
 | 
					 | 
				
			||||||
        epsilon = torch.finfo(x.dtype).eps
 | 
					 | 
				
			||||||
        normed_x = (x / x.pow(2).sum(dim=tuple(range(
 | 
					 | 
				
			||||||
            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
 | 
					 | 
				
			||||||
                start_dim=1)
 | 
					 | 
				
			||||||
        normed_y = (y / y.pow(2).sum(dim=tuple(range(
 | 
					 | 
				
			||||||
            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
 | 
					 | 
				
			||||||
                start_dim=1)
 | 
					 | 
				
			||||||
        # normed_x = (x / torch.linalg.norm(x, dim=1))
 | 
					 | 
				
			||||||
        diss = torch.inner(normed_x, normed_y)
 | 
					 | 
				
			||||||
        return self.activation(diss)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class MarginLoss(torch.nn.modules.loss._Loss):
 | 
					 | 
				
			||||||
    def __init__(self,
 | 
					 | 
				
			||||||
                 margin=0.3,
 | 
					 | 
				
			||||||
                 size_average=None,
 | 
					 | 
				
			||||||
                 reduce=None,
 | 
					 | 
				
			||||||
                 reduction="mean"):
 | 
					 | 
				
			||||||
        super().__init__(size_average, reduce, reduction)
 | 
					 | 
				
			||||||
        self.margin = margin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, input_, target):
 | 
					 | 
				
			||||||
        dp = torch.sum(target * input_, dim=-1)
 | 
					 | 
				
			||||||
        dm = torch.max(input_ - target, dim=-1).values
 | 
					 | 
				
			||||||
        return torch.nn.functional.relu(dm - dp + self.margin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ReasoningLayer(torch.nn.Module):
 | 
					 | 
				
			||||||
    def __init__(self, num_components, num_classes, num_replicas=1):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.num_replicas = num_replicas
 | 
					 | 
				
			||||||
        self.num_classes = num_classes
 | 
					 | 
				
			||||||
        probabilities_init = torch.zeros(2, 1, num_components,
 | 
					 | 
				
			||||||
                                         self.num_classes)
 | 
					 | 
				
			||||||
        probabilities_init.uniform_(0.4, 0.6)
 | 
					 | 
				
			||||||
        # TODO Use `self.register_parameter("param", Paramater(param))` instead
 | 
					 | 
				
			||||||
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def reasonings(self):
 | 
					 | 
				
			||||||
        pk = self.reasoning_probabilities[0]
 | 
					 | 
				
			||||||
        nk = (1 - pk) * self.reasoning_probabilities[1]
 | 
					 | 
				
			||||||
        ik = 1 - pk - nk
 | 
					 | 
				
			||||||
        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
 | 
					 | 
				
			||||||
        return img.unsqueeze(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, detections):
 | 
					 | 
				
			||||||
        pk = self.reasoning_probabilities[0].clamp(0, 1)
 | 
					 | 
				
			||||||
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
 | 
					 | 
				
			||||||
        numerator = (detections @ (pk - nk)) + nk.sum(1)
 | 
					 | 
				
			||||||
        probs = numerator / (pk + nk).sum(1)
 | 
					 | 
				
			||||||
        probs = probs.squeeze(0)
 | 
					 | 
				
			||||||
        return probs
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,16 +1,14 @@
 | 
				
			|||||||
"""Models based on the GLVQ framework."""
 | 
					"""Models based on the GLVQ framework."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.activations import get_activation
 | 
					 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					 | 
				
			||||||
from prototorch.functions.distances import (lomega_distance, omega_distance,
 | 
					 | 
				
			||||||
                                            squared_euclidean_distance)
 | 
					 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					 | 
				
			||||||
from prototorch.components import LinearMapping
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import wtac
 | 
				
			||||||
 | 
					from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 | 
				
			||||||
 | 
					from ..core.initializers import EyeTransformInitializer
 | 
				
			||||||
 | 
					from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
 | 
					from ..nn.activations import get_activation
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
					from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -30,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
 | 
				
			|||||||
        # Loss
 | 
					        # Loss
 | 
				
			||||||
        self.loss = LossLayer(glvq_loss)
 | 
					        self.loss = LossLayer(glvq_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Prototype metrics
 | 
					 | 
				
			||||||
        self.initialize_prototype_win_ratios()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def initialize_prototype_win_ratios(self):
 | 
					    def initialize_prototype_win_ratios(self):
 | 
				
			||||||
        self.register_buffer(
 | 
					        self.register_buffer(
 | 
				
			||||||
            "prototype_win_ratios",
 | 
					            "prototype_win_ratios",
 | 
				
			||||||
@@ -59,7 +54,7 @@ class GLVQ(SupervisedPrototypeModel):
 | 
				
			|||||||
    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = batch
 | 
					        x, y = batch
 | 
				
			||||||
        out = self.compute_distances(x)
 | 
					        out = self.compute_distances(x)
 | 
				
			||||||
        plabels = self.proto_layer.component_labels
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
        mu = self.loss(out, y, prototype_labels=plabels)
 | 
					        mu = self.loss(out, y, prototype_labels=plabels)
 | 
				
			||||||
        batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
 | 
					        batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
 | 
				
			||||||
        loss = batch_loss.sum(dim=0)
 | 
					        loss = batch_loss.sum(dim=0)
 | 
				
			||||||
@@ -135,7 +130,7 @@ class SiameseGLVQ(GLVQ):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
        x, protos = get_flat(x, protos)
 | 
					        x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					        self.backbone.requires_grad_(self.both_path_gradients)
 | 
				
			||||||
        latent_protos = self.backbone(protos)
 | 
					        latent_protos = self.backbone(protos)
 | 
				
			||||||
@@ -240,17 +235,13 @@ class GMLVQ(GLVQ):
 | 
				
			|||||||
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
					        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Additional parameters
 | 
					        # Additional parameters
 | 
				
			||||||
        omega_initializer = kwargs.get("omega_initializer", None)
 | 
					        omega_initializer = kwargs.get("omega_initializer",
 | 
				
			||||||
        initialized_omega = kwargs.get("initialized_omega", None)
 | 
					                                       EyeTransformInitializer())
 | 
				
			||||||
        if omega_initializer is not None or initialized_omega is not None:
 | 
					        omega = omega_initializer.generate(self.hparams.input_dim,
 | 
				
			||||||
            self.omega_layer = LinearMapping(
 | 
					                                           self.hparams.latent_dim)
 | 
				
			||||||
                mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim),
 | 
					        self.register_parameter("_omega", Parameter(omega))
 | 
				
			||||||
                initializer=omega_initializer,
 | 
					        self.backbone = LambdaLayer(lambda x: x @ self._omega,
 | 
				
			||||||
                initialized_linearmapping=initialized_omega,
 | 
					                                    name="omega matrix")
 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
 | 
					 | 
				
			||||||
        self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def omega_matrix(self):
 | 
					    def omega_matrix(self):
 | 
				
			||||||
@@ -264,24 +255,6 @@ class GMLVQ(GLVQ):
 | 
				
			|||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f"(omega): (shape: {tuple(self._omega.shape)})"
 | 
					        return f"(omega): (shape: {tuple(self._omega.shape)})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict_latent(self, x, map_protos=True):
 | 
					 | 
				
			||||||
        """Predict `x` assuming it is already embedded in the latent space.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Only the prototypes are embedded in the latent space using the
 | 
					 | 
				
			||||||
        backbone.
 | 
					 | 
				
			||||||
 
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        self.eval()
 | 
					 | 
				
			||||||
        with torch.no_grad():
 | 
					 | 
				
			||||||
            protos, plabels = self.proto_layer()
 | 
					 | 
				
			||||||
            if map_protos:
 | 
					 | 
				
			||||||
                protos = self.backbone(protos)
 | 
					 | 
				
			||||||
            d = squared_euclidean_distance(x, protos)
 | 
					 | 
				
			||||||
            y_pred = wtac(d, plabels)
 | 
					 | 
				
			||||||
        return y_pred
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LGMLVQ(GMLVQ):
 | 
					class LGMLVQ(GMLVQ):
 | 
				
			||||||
    """Localized and Generalized Matrix Learning Vector Quantization."""
 | 
					    """Localized and Generalized Matrix Learning Vector Quantization."""
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,9 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.components import LabeledComponents
 | 
					from ..core.competitions import KNNC
 | 
				
			||||||
from prototorch.modules import KNNC
 | 
					from ..core.components import LabeledComponents
 | 
				
			||||||
 | 
					from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
 | 
				
			||||||
 | 
					from ..utils.utils import parse_data_arg
 | 
				
			||||||
from .abstract import SupervisedPrototypeModel
 | 
					from .abstract import SupervisedPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
 | 
				
			|||||||
        data = kwargs.get("data", None)
 | 
					        data = kwargs.get("data", None)
 | 
				
			||||||
        if data is None:
 | 
					        if data is None:
 | 
				
			||||||
            raise ValueError("KNN requires data, but was not provided!")
 | 
					            raise ValueError("KNN requires data, but was not provided!")
 | 
				
			||||||
 | 
					        data, targets = parse_data_arg(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        self.proto_layer = LabeledComponents(initialized_components=data)
 | 
					        self.proto_layer = LabeledComponents(
 | 
				
			||||||
 | 
					            distribution=[],
 | 
				
			||||||
 | 
					            components_initializer=LiteralCompInitializer(data),
 | 
				
			||||||
 | 
					            labels_initializer=LiteralLabelsInitializer(targets))
 | 
				
			||||||
        self.competition_layer = KNNC(k=self.hparams.k)
 | 
					        self.competition_layer = KNNC(k=self.hparams.k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,6 @@
 | 
				
			|||||||
"""LVQ models that are optimized using non-gradient methods."""
 | 
					"""LVQ models that are optimized using non-gradient methods."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.functions.losses import _get_dp_dm
 | 
					from ..core.losses import _get_dp_dm
 | 
				
			||||||
 | 
					 | 
				
			||||||
from .abstract import NonGradientMixin
 | 
					from .abstract import NonGradientMixin
 | 
				
			||||||
from .glvq import GLVQ
 | 
					from .glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -10,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ):
 | 
				
			|||||||
    """Learning Vector Quantization 1."""
 | 
					    """Learning Vector Quantization 1."""
 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        protos = self.proto_layer.components
 | 
					        protos = self.proto_layer.components
 | 
				
			||||||
        plabels = self.proto_layer.component_labels
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        x, y = train_batch
 | 
					        x, y = train_batch
 | 
				
			||||||
        dis = self.compute_distances(x)
 | 
					        dis = self.compute_distances(x)
 | 
				
			||||||
@@ -29,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
 | 
				
			|||||||
            self.proto_layer.load_state_dict({"_components": updated_protos},
 | 
					            self.proto_layer.load_state_dict({"_components": updated_protos},
 | 
				
			||||||
                                             strict=False)
 | 
					                                             strict=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print(f"{dis=}")
 | 
				
			||||||
 | 
					        print(f"{y=}")
 | 
				
			||||||
        # Logging
 | 
					        # Logging
 | 
				
			||||||
        self.log_acc(dis, y, tag="train_acc")
 | 
					        self.log_acc(dis, y, tag="train_acc")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -39,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ):
 | 
				
			|||||||
    """Learning Vector Quantization 2.1."""
 | 
					    """Learning Vector Quantization 2.1."""
 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        protos = self.proto_layer.components
 | 
					        protos = self.proto_layer.components
 | 
				
			||||||
        plabels = self.proto_layer.component_labels
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        x, y = train_batch
 | 
					        x, y = train_batch
 | 
				
			||||||
        dis = self.compute_distances(x)
 | 
					        dis = self.compute_distances(x)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,11 @@
 | 
				
			|||||||
"""Probabilistic GLVQ methods"""
 | 
					"""Probabilistic GLVQ methods"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
 | 
					 | 
				
			||||||
from prototorch.functions.pooling import (stratified_min_pooling,
 | 
					 | 
				
			||||||
                                          stratified_sum_pooling)
 | 
					 | 
				
			||||||
from prototorch.functions.transforms import (GaussianPrior,
 | 
					 | 
				
			||||||
                                             RankScaledGaussianPrior)
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.losses import nllr_loss, rslvq_loss
 | 
				
			||||||
 | 
					from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
 | 
					from .extras import GaussianPrior, RankScaledGaussianPrior
 | 
				
			||||||
from .glvq import GLVQ, SiameseGMLVQ
 | 
					from .glvq import GLVQ, SiameseGMLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -22,7 +20,7 @@ class CELVQ(GLVQ):
 | 
				
			|||||||
    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = batch
 | 
					        x, y = batch
 | 
				
			||||||
        out = self.compute_distances(x)  # [None, num_protos]
 | 
					        out = self.compute_distances(x)  # [None, num_protos]
 | 
				
			||||||
        plabels = self.proto_layer.component_labels
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
        winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
 | 
					        winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
 | 
				
			||||||
        probs = -1.0 * winning
 | 
					        probs = -1.0 * winning
 | 
				
			||||||
        batch_loss = self.loss(probs, y.long())
 | 
					        batch_loss = self.loss(probs, y.long())
 | 
				
			||||||
@@ -56,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
 | 
				
			|||||||
    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = batch
 | 
					        x, y = batch
 | 
				
			||||||
        out = self.forward(x)
 | 
					        out = self.forward(x)
 | 
				
			||||||
        plabels = self.proto_layer.component_labels
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
        batch_loss = self.loss(out, y, plabels)
 | 
					        batch_loss = self.loss(out, y, plabels)
 | 
				
			||||||
        loss = batch_loss.sum(dim=0)
 | 
					        loss = batch_loss.sum(dim=0)
 | 
				
			||||||
        return loss
 | 
					        return loss
 | 
				
			||||||
@@ -89,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
 | 
				
			|||||||
            self.hparams.lambd)
 | 
					            self.hparams.lambd)
 | 
				
			||||||
        self.loss = torch.nn.KLDivLoss()
 | 
					        self.loss = torch.nn.KLDivLoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    # FIXME
 | 
				
			||||||
        x, y = batch
 | 
					    # def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        out = self.forward(x)
 | 
					    #     x, y = batch
 | 
				
			||||||
        y_dist = torch.nn.functional.one_hot(
 | 
					    #     y_pred = self(x)
 | 
				
			||||||
            y.long(), num_classes=self.num_classes).float()
 | 
					    #     batch_loss = self.loss(y_pred, y)
 | 
				
			||||||
        batch_loss = self.loss(out, y_dist)
 | 
					    #     loss = batch_loss.sum(dim=0)
 | 
				
			||||||
        loss = batch_loss.sum(dim=0)
 | 
					    #     return loss
 | 
				
			||||||
        return loss
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					 | 
				
			||||||
from prototorch.functions.distances import squared_euclidean_distance
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer
 | 
					 | 
				
			||||||
from prototorch.modules.losses import NeuralGasEnergy
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import wtac
 | 
				
			||||||
 | 
					from ..core.distances import squared_euclidean_distance
 | 
				
			||||||
 | 
					from ..core.losses import NeuralGasEnergy
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
					from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
				
			||||||
from .callbacks import GNGCallback
 | 
					from .callbacks import GNGCallback
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,8 @@ import torchvision
 | 
				
			|||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
from torch.utils.data import DataLoader, Dataset
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..utils.utils import mesh2d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Vis2DAbstract(pl.Callback):
 | 
					class Vis2DAbstract(pl.Callback):
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
@@ -73,23 +75,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
            ax.axis("off")
 | 
					            ax.axis("off")
 | 
				
			||||||
        return ax
 | 
					        return ax
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mesh_input(self, x):
 | 
					    def plot_data(self, ax, x, y):
 | 
				
			||||||
        x_shift = self.border * np.ptp(x[:, 0])
 | 
					 | 
				
			||||||
        y_shift = self.border * np.ptp(x[:, 1])
 | 
					 | 
				
			||||||
        x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
 | 
					 | 
				
			||||||
        y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
 | 
					 | 
				
			||||||
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
 | 
					 | 
				
			||||||
                             np.linspace(y_min, y_max, self.resolution))
 | 
					 | 
				
			||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					 | 
				
			||||||
        return mesh_input, xx, yy
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def perform_pca_2D(self, data):
 | 
					 | 
				
			||||||
        (_, eigVal, eigVec) = torch.pca_lowrank(data, q=2)
 | 
					 | 
				
			||||||
        return data @ eigVec
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def plot_data(self, ax, x, y, pca=False):
 | 
					 | 
				
			||||||
        if pca:
 | 
					 | 
				
			||||||
            x = self.perform_pca_2D(x)
 | 
					 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            x[:, 0],
 | 
					            x[:, 0],
 | 
				
			||||||
            x[:, 1],
 | 
					            x[:, 1],
 | 
				
			||||||
@@ -100,9 +86,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
            s=30,
 | 
					            s=30,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def plot_protos(self, ax, protos, plabels, pca=False):
 | 
					    def plot_protos(self, ax, protos, plabels):
 | 
				
			||||||
        if pca:
 | 
					 | 
				
			||||||
            protos = self.perform_pca_2D(protos)
 | 
					 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
            protos[:, 1],
 | 
					            protos[:, 1],
 | 
				
			||||||
@@ -146,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        self.plot_protos(ax, protos, plabels)
 | 
					        self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
				
			||||||
        _components = pl_module.proto_layer._components
 | 
					        _components = pl_module.proto_layer._components
 | 
				
			||||||
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
					        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
				
			||||||
        y_pred = pl_module.predict(mesh_input)
 | 
					        y_pred = pl_module.predict(mesh_input)
 | 
				
			||||||
@@ -181,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        if self.show_protos:
 | 
					        if self.show_protos:
 | 
				
			||||||
            self.plot_protos(ax, protos, plabels)
 | 
					            self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
            x = np.vstack((x_train, protos))
 | 
					            x = np.vstack((x_train, protos))
 | 
				
			||||||
            mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					            mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            mesh_input, xx, yy = self.get_mesh_input(x_train)
 | 
					            mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
 | 
				
			||||||
        _components = pl_module.proto_layer._components
 | 
					        _components = pl_module.proto_layer._components
 | 
				
			||||||
        mesh_input = torch.Tensor(mesh_input).type_as(_components)
 | 
					        mesh_input = torch.Tensor(mesh_input).type_as(_components)
 | 
				
			||||||
        y_pred = pl_module.predict_latent(mesh_input,
 | 
					        y_pred = pl_module.predict_latent(mesh_input,
 | 
				
			||||||
@@ -194,50 +178,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.log_and_display(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisGMLVQ2D(Vis2DAbstract):
 | 
					 | 
				
			||||||
    def __init__(self, *args, map_protos=True, **kwargs):
 | 
					 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					 | 
				
			||||||
        self.map_protos = map_protos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					 | 
				
			||||||
        plabels = pl_module.prototype_labels
 | 
					 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					 | 
				
			||||||
        device = pl_module.device
 | 
					 | 
				
			||||||
        with torch.no_grad():
 | 
					 | 
				
			||||||
            x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
 | 
					 | 
				
			||||||
            x_train = x_train.cpu().detach()
 | 
					 | 
				
			||||||
        if self.map_protos:
 | 
					 | 
				
			||||||
            with torch.no_grad():
 | 
					 | 
				
			||||||
                protos = pl_module.backbone(torch.Tensor(protos).to(device))
 | 
					 | 
				
			||||||
                protos = protos.cpu().detach()
 | 
					 | 
				
			||||||
        ax = self.setup_ax()
 | 
					 | 
				
			||||||
        if x_train.shape[1] > 2:
 | 
					 | 
				
			||||||
            self.plot_data(ax, x_train, y_train, pca=True)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.plot_data(ax, x_train, y_train, pca=False)
 | 
					 | 
				
			||||||
        if self.show_protos:
 | 
					 | 
				
			||||||
            if protos.shape[1] > 2:
 | 
					 | 
				
			||||||
                self.plot_protos(ax, protos, plabels, pca=True)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                self.plot_protos(ax, protos, plabels, pca=False)
 | 
					 | 
				
			||||||
        ### something to work on: meshgrid with pca
 | 
					 | 
				
			||||||
        #    x = np.vstack((x_train, protos))
 | 
					 | 
				
			||||||
        #    mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					 | 
				
			||||||
        #else:
 | 
					 | 
				
			||||||
        #    mesh_input, xx, yy = self.get_mesh_input(x_train)
 | 
					 | 
				
			||||||
        #_components = pl_module.proto_layer._components
 | 
					 | 
				
			||||||
        #mesh_input = torch.Tensor(mesh_input).type_as(_components)
 | 
					 | 
				
			||||||
        #y_pred = pl_module.predict_latent(mesh_input,
 | 
					 | 
				
			||||||
        #                                  map_protos=self.map_protos)
 | 
					 | 
				
			||||||
        #y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					 | 
				
			||||||
        #ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					 | 
				
			||||||
        self.log_and_display(trainer, pl_module)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class VisCBC2D(Vis2DAbstract):
 | 
					class VisCBC2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.precheck(trainer):
 | 
				
			||||||
@@ -250,8 +190,8 @@ class VisCBC2D(Vis2DAbstract):
 | 
				
			|||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        self.plot_protos(ax, protos, "w")
 | 
					        self.plot_protos(ax, protos, "w")
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
				
			||||||
        _components = pl_module.component_layer._components
 | 
					        _components = pl_module.components_layer._components
 | 
				
			||||||
        y_pred = pl_module.predict(
 | 
					        y_pred = pl_module.predict(
 | 
				
			||||||
            torch.Tensor(mesh_input).type_as(_components))
 | 
					            torch.Tensor(mesh_input).type_as(_components))
 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user