refactor(api)!: merge the new api changes into dev

This commit is contained in:
Jensun Ravichandran 2021-06-20 19:00:12 +02:00
commit d42693a441
30 changed files with 368 additions and 457 deletions

View File

@ -4,7 +4,10 @@ commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch} serialize = {major}.{minor}.{patch}
message = bump: {current_version} → {new_version}
[bumpversion:file:setup.py] [bumpversion:file:setup.py]
[bumpversion:file:./prototorch/models/__init__.py] [bumpversion:file:./prototorch/models/__init__.py]
[bumpversion:file:./docs/source/conf.py]

17
.gitignore vendored
View File

@ -128,14 +128,19 @@ dmypy.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
# Datasets
datasets/
# PyTorch-Lightning
lightning_logs/
.vscode/ .vscode/
# Vim
*~
*.swp
*.swo
# Pytorch Models or Weights # Pytorch Models or Weights
# If necessary make exceptions for single pretrained models # If necessary make exceptions for single pretrained models
*.pt *.pt
# Artifacts created by ProtoTorch Models
datasets/
lightning_logs/
examples/_*.py
examples/_*.ipynb

View File

@ -1,54 +1,53 @@
# See https://pre-commit.com for more information # See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake - repo: https://github.com/myint/autoflake
rev: v1.4 rev: v1.4
hooks: hooks:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.8.0 rev: 5.8.0
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.902' rev: v0.902
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
additional_dependencies: [types-pkg_resources] additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.31.0' # Use the sha / tag you want to point at rev: v0.31.0
hooks: hooks:
- id: yapf - id: yapf
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0 # Use the ref you want to point at rev: v1.9.0
hooks: hooks:
- id: python-use-type-annotations - id: python-use-type-annotations
- id: python-no-log-warn - id: python-no-log-warn
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
hooks:
- id: pyupgrade
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/si-cim/gitlint
rev: v2.19.4 rev: v0.15.2-unofficial
hooks: hooks:
- id: pyupgrade - id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/jorisroovers/gitlint
rev: "v0.15.1"
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]

View File

@ -20,23 +20,6 @@ pip install prototorch_models
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
be available for use in your Python environment as `prototorch.models`. be available for use in your Python environment as `prototorch.models`.
## Contribution
This repository contains definition for [git hooks](https://githooks.com).
[Pre-commit](https://pre-commit.com) is automatically installed as development
dependency with prototorch or you can install it manually with `pip install
pre-commit`.
Please install the hooks by running:
```bash
pre-commit install
pre-commit install --hook-type commit-msg
```
before creating the first commit.
The commit will fail if the commit message does not follow the specification
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
## Available models ## Available models
### LVQ Family ### LVQ Family
@ -103,6 +86,23 @@ To assist in the development process, you may also find it useful to install
please avoid installing Tensorflow in this environment. It is known to cause please avoid installing Tensorflow in this environment. It is known to cause
problems with PyTorch-Lightning.** problems with PyTorch-Lightning.**
## Contribution
This repository contains definition for [git hooks](https://githooks.com).
[Pre-commit](https://pre-commit.com) is automatically installed as development
dependency with prototorch or you can install it manually with `pip install
pre-commit`.
Please install the hooks by running:
```bash
pre-commit install
pre-commit install --hook-type commit-msg
```
before creating the first commit.
The commit will fail if the commit message does not follow the specification
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
## FAQ ## FAQ
### How do I update the plugin? ### How do I update the plugin?

View File

@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.4.4" release = "0.1.8"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -24,14 +23,18 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[2, 2, 2], distribution=[1, 0, 3],
proto_lr=0.1, margin=0.1,
proto_lr=0.01,
bb_lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.CBC( model = pt.models.CBC(
hparams, hparams,
prototype_initializer=pt.components.SSI(train_ds, noise=0.01), components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
reasonings_iniitializer=pt.initializers.
PurePositiveReasoningsInitializer(),
) )
# Callbacks # Callbacks

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -37,7 +36,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.CELVQ( model = pt.models.CELVQ(
hparams, hparams,
prototype_initializer=pt.components.Ones(2, scale=3), prototypes_initializer=pt.initializers.FVCI(2, 3.0),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes

View File

@ -2,12 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -24,7 +23,7 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
distribution={ distribution={
"num_classes": 3, "num_classes": 3,
"prototypes_per_class": 4 "per_class": 4
}, },
lr=0.01, lr=0.01,
) )
@ -33,7 +32,7 @@ if __name__ == "__main__":
model = pt.models.GLVQ( model = pt.models.GLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR, lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
) )

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -26,7 +25,6 @@ if __name__ == "__main__":
distribution=(num_classes, prototypes_per_class), distribution=(num_classes, prototypes_per_class),
transfer_function="swish_beta", transfer_function="swish_beta",
transfer_beta=10.0, transfer_beta=10.0,
# lr=0.1,
proto_lr=0.1, proto_lr=0.1,
bb_lr=0.1, bb_lr=0.1,
input_dim=2, input_dim=2,
@ -37,7 +35,7 @@ if __name__ == "__main__":
model = pt.models.GMLVQ( model = pt.models.GMLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SSI(train_ds, noise=1e-2), prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
) )
# Callbacks # Callbacks
@ -47,12 +45,12 @@ if __name__ == "__main__":
block=False, block=False,
) )
pruning = pt.models.PruneLoserPrototypes( pruning = pt.models.PruneLoserPrototypes(
threshold=0.02, threshold=0.01,
idle_epochs=10, idle_epochs=10,
prune_quota_per_epoch=5, prune_quota_per_epoch=5,
frequency=2, frequency=5,
replace=True, replace=True,
initializer=pt.components.SSI(train_ds, noise=1e-2), prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
verbose=True, verbose=True,
) )
es = pl.callbacks.EarlyStopping( es = pl.callbacks.EarlyStopping(
@ -68,7 +66,7 @@ if __name__ == "__main__":
args, args,
callbacks=[ callbacks=[
vis, vis,
# es, # es, # FIXME
pruning, pruning,
], ],
terminate_on_nan=True, terminate_on_nan=True,

View File

@ -1,59 +0,0 @@
"""GLVQ example using the Iris dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=4,
latent_dim=3,
distribution={
"num_classes": 3,
"prototypes_per_class": 2
},
proto_lr=0.0005,
bb_lr=0.0005,
)
# Initialize the model
model = pt.models.GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SSI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
omega_initializer=pt.components.PCA(train_ds.data)
)
# Compute intermediate input and output sizes
#model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -30,7 +29,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.GrowingNeuralGas( model = pt.models.GrowingNeuralGas(
hparams, hparams,
prototype_initializer=pt.components.Zeros(2), prototypes_initializer=pt.initializers.ZCI(2),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes

View File

@ -2,25 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.utils.colors import hex_to_rgb
import prototorch as pt
def hex_to_rgb(hex_values):
for v in hex_values:
v = v.lstrip('#')
lv = len(v)
c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
yield c
def rgb_to_hex(rgb_values):
for v in rgb_values:
c = "%02x%02x%02x" % tuple(v)
yield c
class Vis2DColorSOM(pl.Callback): class Vis2DColorSOM(pl.Callback):
@ -93,7 +79,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.KohonenSOM( model = pt.models.KohonenSOM(
hparams, hparams,
prototype_initializer=pt.components.Random(3), prototypes_initializer=pt.initializers.RNCI(3),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes

View File

@ -2,23 +2,22 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args() args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=2) pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256, batch_size=256,
@ -32,8 +31,10 @@ if __name__ == "__main__":
) )
# Initialize the model # Initialize the model
model = pt.models.LGMLVQ(hparams, model = pt.models.LGMLVQ(
prototype_initializer=pt.components.SMI(train_ds)) hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Compute intermediate input and output sizes # Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2) model.example_input_array = torch.zeros(4, 2)

View File

@ -3,11 +3,10 @@
import argparse import argparse
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
def plot_matrix(matrix): def plot_matrix(matrix):
title = "Lambda matrix" title = "Lambda matrix"
@ -40,20 +39,19 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
distribution={ distribution={
"num_classes": 2, "num_classes": 2,
"prototypes_per_class": 1 "per_class": 1,
}, },
input_dim=100, input_dim=100,
latent_dim=2, latent_dim=2,
proto_lr=0.0001, proto_lr=0.001,
bb_lr=0.0001, bb_lr=0.001,
) )
# Initialize the model # Initialize the model
model = pt.models.SiameseGMLVQ( model = pt.models.SiameseGMLVQ(
hparams, hparams,
# optimizer=torch.optim.SGD,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
) )
# Summary # Summary

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -41,7 +40,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 2], distribution=[3, 4, 5],
proto_lr=0.001, proto_lr=0.001,
bb_lr=0.001, bb_lr=0.001,
) )
@ -52,7 +51,10 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.LVQMLN( model = pt.models.LVQMLN(
hparams, hparams,
prototype_initializer=pt.components.SSI(train_ds, transform=backbone), prototypes_initializer=pt.initializers.SSCI(
train_ds,
transform=backbone,
),
backbone=backbone, backbone=backbone,
) )
@ -67,11 +69,21 @@ if __name__ == "__main__":
resolution=500, resolution=500,
axis_off=True, axis_off=True,
) )
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
frequency=10,
verbose=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
callbacks=[vis], callbacks=[
vis,
pruning,
],
) )
# Training loop # Training loop

View File

@ -2,11 +2,9 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchvision.transforms import Lambda
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
@ -28,19 +26,17 @@ if __name__ == "__main__":
distribution=[2, 2, 3], distribution=[2, 2, 3],
proto_lr=0.05, proto_lr=0.05,
lambd=0.1, lambd=0.1,
variance=1.0,
input_dim=2, input_dim=2,
latent_dim=2, latent_dim=2,
bb_lr=0.01, bb_lr=0.01,
) )
# Initialize the model # Initialize the model
model = pt.models.probabilistic.PLVQ( model = pt.models.RSLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
# prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
# prototype_initializer=pt.components.Zeros(2),
# prototype_initializer=pt.components.Ones(2, scale=2.0),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes
@ -50,7 +46,7 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=train_ds) vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -52,7 +51,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.SiameseGLVQ( model = pt.models.SiameseGLVQ(
hparams, hparams,
prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone, backbone=backbone,
both_path_gradients=False, both_path_gradients=False,
) )

84
examples/warm_starting.py Normal file
View File

@ -0,0 +1,84 @@
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Initialize the gng
gng = pt.models.GrowingNeuralGas(
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
prototypes_initializer=pt.initializers.ZCI(2),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=200,
callbacks=[es],
weights_summary=None,
)
# Training loop
trainer.fit(gng, train_loader)
# Hyperparameters
hparams = dict(
distribution=[],
lr=0.01,
)
# Warm-start prototypes
knn = pt.models.KNN(dict(k=1), data=train_ds)
prototypes = gng.prototypes
plabels = knn.predict(prototypes)
# Initialize the model
model = pt.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.LCI(prototypes),
labels_initializer=pt.initializers.LLI(plabels),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, from .glvq import (
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .knn import KNN from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ

View File

@ -5,9 +5,13 @@ from typing import Final, final
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.functions.distances import euclidean_distance from ..core.competitions import WTAC
from prototorch.modules import WTAC, LambdaLayer from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchMixin(object): class ProtoTorchMixin(object):
@ -85,13 +89,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
initialized_prototypes = kwargs.get("initialized_prototypes", None) if prototypes_initializer is not None:
if prototype_initializer is not None or initialized_prototypes is not None:
self.proto_layer = Components( self.proto_layer = Components(
self.hparams.num_prototypes, self.hparams.num_prototypes,
initializer=prototype_initializer, initializer=prototypes_initializer,
initialized_components=initialized_prototypes,
) )
def compute_distances(self, x): def compute_distances(self, x):
@ -109,23 +111,24 @@ class SupervisedPrototypeModel(PrototypeModel):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
initialized_prototypes = kwargs.get("initialized_prototypes", None) labels_initializer = kwargs.get("labels_initializer",
if prototype_initializer is not None or initialized_prototypes is not None: LabelsInitializer())
if prototypes_initializer is not None:
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=prototype_initializer, components_initializer=prototypes_initializer,
initialized_components=initialized_prototypes, labels_initializer=labels_initializer,
) )
self.competition_layer = WTAC() self.competition_layer = WTAC()
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu() return self.proto_layer.labels.detach().cpu()
@property @property
def num_classes(self): def num_classes(self):
return len(self.proto_layer.distribution) return self.proto_layer.num_classes
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
@ -134,15 +137,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x): def forward(self, x):
distances = self.compute_distances(x) distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances) plabels = self.proto_layer.labels
# TODO winning = stratified_min_pooling(distances, plabels)
y_pred = torch.eye(self.num_classes, device=self.device)[ y_pred = torch.nn.functional.softmin(winning)
y_pred.long()] # depends on labels {0,...,num_classes}
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
with torch.no_grad(): with torch.no_grad():
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
y_pred = self.competition_layer(distances, plabels) y_pred = self.competition_layer(distances, plabels)
return y_pred return y_pred

View File

@ -4,8 +4,9 @@ import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.components import Components
from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology from .extras import ConnectionTopology
@ -16,7 +17,7 @@ class PruneLoserPrototypes(pl.Callback):
prune_quota_per_epoch=-1, prune_quota_per_epoch=-1,
frequency=1, frequency=1,
replace=False, replace=False,
initializer=None, prototypes_initializer=None,
verbose=False): verbose=False):
self.threshold = threshold # minimum win ratio self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning self.idle_epochs = idle_epochs # epochs to wait before pruning
@ -24,7 +25,7 @@ class PruneLoserPrototypes(pl.Callback):
self.frequency = frequency self.frequency = frequency
self.replace = replace self.replace = replace
self.verbose = verbose self.verbose = verbose
self.initializer = initializer self.prototypes_initializer = prototypes_initializer
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs: if (trainer.current_epoch + 1) < self.idle_epochs:
@ -55,8 +56,9 @@ class PruneLoserPrototypes(pl.Callback):
if self.verbose: if self.verbose:
print(f"Re-adding pruned prototypes...") print(f"Re-adding pruned prototypes...")
print(f"{distribution=}") print(f"{distribution=}")
pl_module.add_prototypes(distribution=distribution, pl_module.add_prototypes(
initializer=self.initializer) distribution=distribution,
components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes new_num_protos = pl_module.num_prototypes
if self.verbose: if self.verbose:
print(f"`num_prototypes` changed from {cur_num_protos} " print(f"`num_prototypes` changed from {cur_num_protos} "
@ -116,7 +118,8 @@ class GNGCallback(pl.Callback):
# Add component # Add component
pl_module.proto_layer.add_components( pl_module.proto_layer.add_components(
initialized_components=new_component.unsqueeze(0)) None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()

View File

@ -1,49 +1,54 @@
import torch import torch
import torchmetrics import torchmetrics
from ..core.competitions import CBCC
from ..core.components import ReasoningComponents
from ..core.initializers import RandomReasoningsInitializer
from ..core.losses import MarginLoss
from ..core.similarities import euclidean_similarity
from ..nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin from .abstract import ImagePrototypesMixin
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
euclidean_similarity, rescaled_cosine_similarity,
shift_activation)
from .glvq import SiameseGLVQ from .glvq import SiameseGLVQ
class CBC(SiameseGLVQ): class CBC(SiameseGLVQ):
"""Classification-By-Components.""" """Classification-By-Components."""
def __init__(self, hparams, margin=0.1, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.margin = margin
self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
num_components = self.components.shape[0]
self.reasoning_layer = ReasoningLayer(num_components=num_components,
num_classes=self.num_classes)
self.component_layer = self.proto_layer
@property similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
def components(self): components_initializer = kwargs.get("components_initializer", None)
return self.prototypes reasonings_initializer = kwargs.get("reasonings_initializer",
RandomReasoningsInitializer())
self.components_layer = ReasoningComponents(
self.hparams.distribution,
components_initializer=components_initializer,
reasonings_initializer=reasonings_initializer,
)
self.similarity_layer = LambdaLayer(similarity_fn)
self.competition_layer = CBCC()
@property # Namespace hook
def reasonings(self): self.proto_layer = self.components_layer
return self.reasoning_layer.reasonings.cpu()
self.loss = MarginLoss(self.hparams.margin)
def forward(self, x): def forward(self, x):
components, _ = self.component_layer() components, reasonings = self.components_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients) self.backbone.requires_grad_(self.both_path_gradients)
latent_components = self.backbone(components) latent_components = self.backbone(components)
self.backbone.requires_grad_(True) self.backbone.requires_grad_(True)
detections = self.similarity_fn(latent_x, latent_components) detections = self.similarity_layer(latent_x, latent_components)
probs = self.reasoning_layer(detections) probs = self.competition_layer(detections, reasonings)
return probs return probs
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
# x = x.view(x.size(0), -1)
y_pred = self(x) y_pred = self(x)
num_classes = self.reasoning_layer.num_classes num_classes = self.num_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes) y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0) loss = self.loss(y_pred, y_true).mean(dim=0)
return y_pred, loss return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
@ -70,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
"""CBC model that constrains the components to the range [0, 1] by """CBC model that constrains the components to the range [0, 1] by
clamping after updates. clamping after updates.
""" """
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Namespace hook
self.proto_layer = self.component_layer

View File

@ -5,23 +5,32 @@ Modules not yet available in prototorch go here temporarily.
""" """
import torch import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity from ..core.similarities import gaussian
def rescaled_cosine_similarity(x, y): def rank_scaled_gaussian(distances, lambd):
"""Cosine Similarity rescaled to [0, 1].""" order = torch.argsort(distances, dim=1)
similarities = cosine_similarity(x, y) ranks = torch.argsort(order, dim=1)
return (similarities + 1.0) / 2.0 return torch.exp(-torch.exp(-ranks / lambd) * distances)
def shift_activation(x): class GaussianPrior(torch.nn.Module):
return (x + 1.0) / 2.0 def __init__(self, variance):
super().__init__()
self.variance = variance
def forward(self, distances):
return gaussian(distances, self.variance)
def euclidean_similarity(x, y, variance=1.0): class RankScaledGaussianPrior(torch.nn.Module):
d = euclidean_distance(x, y) def __init__(self, lambd):
return torch.exp(-(d * d) / (2 * variance)) super().__init__()
self.lambd = lambd
def forward(self, distances):
return rank_scaled_gaussian(distances, self.lambd)
class ConnectionTopology(torch.nn.Module): class ConnectionTopology(torch.nn.Module):
@ -79,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
def extra_repr(self): def extra_repr(self):
return f"(agelimit): ({self.agelimit})" return f"(agelimit): ({self.agelimit})"
class CosineSimilarity(torch.nn.Module):
def __init__(self, activation=shift_activation):
super().__init__()
self.activation = activation
def forward(self, x, y):
epsilon = torch.finfo(x.dtype).eps
normed_x = (x / x.pow(2).sum(dim=tuple(range(
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
normed_y = (y / y.pow(2).sum(dim=tuple(range(
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
# normed_x = (x / torch.linalg.norm(x, dim=1))
diss = torch.inner(normed_x, normed_y)
return self.activation(diss)
class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self,
margin=0.3,
size_average=None,
reduce=None,
reduction="mean"):
super().__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, input_, target):
dp = torch.sum(target * input_, dim=-1)
dm = torch.max(input_ - target, dim=-1).values
return torch.nn.functional.relu(dm - dp + self.margin)
class ReasoningLayer(torch.nn.Module):
def __init__(self, num_components, num_classes, num_replicas=1):
super().__init__()
self.num_replicas = num_replicas
self.num_classes = num_classes
probabilities_init = torch.zeros(2, 1, num_components,
self.num_classes)
probabilities_init.uniform_(0.4, 0.6)
# TODO Use `self.register_parameter("param", Paramater(param))` instead
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@property
def reasonings(self):
pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1]
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1)
def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
numerator = (detections @ (pk - nk)) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
probs = probs.squeeze(0)
return probs

View File

@ -1,16 +1,14 @@
"""Models based on the GLVQ framework.""" """Models based on the GLVQ framework."""
import torch import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (lomega_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.components import LinearMapping
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@ -30,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
# Loss # Loss
self.loss = LossLayer(glvq_loss) self.loss = LossLayer(glvq_loss)
# Prototype metrics
self.initialize_prototype_win_ratios()
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.register_buffer( self.register_buffer(
"prototype_win_ratios", "prototype_win_ratios",
@ -59,7 +54,7 @@ class GLVQ(SupervisedPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.compute_distances(x) out = self.compute_distances(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
mu = self.loss(out, y, prototype_labels=plabels) mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta) batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0) loss = batch_loss.sum(dim=0)
@ -135,7 +130,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = get_flat(x, protos) x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x) latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients) self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos) latent_protos = self.backbone(protos)
@ -240,18 +235,14 @@ class GMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters # Additional parameters
omega_initializer = kwargs.get("omega_initializer", None) omega_initializer = kwargs.get("omega_initializer",
initialized_omega = kwargs.get("initialized_omega", None) EyeTransformInitializer())
if omega_initializer is not None or initialized_omega is not None: omega = omega_initializer.generate(self.hparams.input_dim,
self.omega_layer = LinearMapping( self.hparams.latent_dim)
mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim), self.register_parameter("_omega", Parameter(omega))
initializer=omega_initializer, self.backbone = LambdaLayer(lambda x: x @ self._omega,
initialized_linearmapping=initialized_omega, name="omega matrix")
)
self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
@property @property
def omega_matrix(self): def omega_matrix(self):
return self._omega.detach().cpu() return self._omega.detach().cpu()
@ -264,24 +255,6 @@ class GMLVQ(GLVQ):
def extra_repr(self): def extra_repr(self):
return f"(omega): (shape: {tuple(self._omega.shape)})" return f"(omega): (shape: {tuple(self._omega.shape)})"
def predict_latent(self, x, map_protos=True):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
self.eval()
with torch.no_grad():
protos, plabels = self.proto_layer()
if map_protos:
protos = self.backbone(protos)
d = squared_euclidean_distance(x, protos)
y_pred = wtac(d, plabels)
return y_pred
class LGMLVQ(GMLVQ): class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization.""" """Localized and Generalized Matrix Learning Vector Quantization."""

View File

@ -2,9 +2,10 @@
import warnings import warnings
from prototorch.components import LabeledComponents from ..core.competitions import KNNC
from prototorch.modules import KNNC from ..core.components import LabeledComponents
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from ..utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel from .abstract import SupervisedPrototypeModel
@ -19,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
data = kwargs.get("data", None) data = kwargs.get("data", None)
if data is None: if data is None:
raise ValueError("KNN requires data, but was not provided!") raise ValueError("KNN requires data, but was not provided!")
data, targets = parse_data_arg(data)
# Layers # Layers
self.proto_layer = LabeledComponents(initialized_components=data) self.proto_layer = LabeledComponents(
distribution=[],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k) self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):

View File

@ -1,7 +1,6 @@
"""LVQ models that are optimized using non-gradient methods.""" """LVQ models that are optimized using non-gradient methods."""
from prototorch.functions.losses import _get_dp_dm from ..core.losses import _get_dp_dm
from .abstract import NonGradientMixin from .abstract import NonGradientMixin
from .glvq import GLVQ from .glvq import GLVQ
@ -10,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1.""" """Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos = self.proto_layer.components
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)
@ -29,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos}, self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False) strict=False)
print(f"{dis=}")
print(f"{y=}")
# Logging # Logging
self.log_acc(dis, y, tag="train_acc") self.log_acc(dis, y, tag="train_acc")
@ -39,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1.""" """Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos = self.proto_layer.components
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)

View File

@ -1,13 +1,11 @@
"""Probabilistic GLVQ methods""" """Probabilistic GLVQ methods"""
import torch import torch
from prototorch.functions.losses import nllr_loss, rslvq_loss
from prototorch.functions.pooling import (stratified_min_pooling,
stratified_sum_pooling)
from prototorch.functions.transforms import (GaussianPrior,
RankScaledGaussianPrior)
from prototorch.modules import LambdaLayer, LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ from .glvq import GLVQ, SiameseGMLVQ
@ -22,7 +20,7 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.compute_distances(x) # [None, num_protos] out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
winning = stratified_min_pooling(out, plabels) # [None, num_classes] winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning probs = -1.0 * winning
batch_loss = self.loss(probs, y.long()) batch_loss = self.loss(probs, y.long())
@ -56,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.forward(x) out = self.forward(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
batch_loss = self.loss(out, y, plabels) batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum(dim=0) loss = batch_loss.sum(dim=0)
return loss return loss
@ -89,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
self.hparams.lambd) self.hparams.lambd)
self.loss = torch.nn.KLDivLoss() self.loss = torch.nn.KLDivLoss()
def training_step(self, batch, batch_idx, optimizer_idx=None): # FIXME
x, y = batch # def training_step(self, batch, batch_idx, optimizer_idx=None):
out = self.forward(x) # x, y = batch
y_dist = torch.nn.functional.one_hot( # y_pred = self(x)
y.long(), num_classes=self.num_classes).float() # batch_loss = self.loss(y_pred, y)
batch_loss = self.loss(out, y_dist) # loss = batch_loss.sum(dim=0)
loss = batch_loss.sum(dim=0) # return loss
return loss

View File

@ -2,11 +2,11 @@
import numpy as np import numpy as np
import torch import torch
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import squared_euclidean_distance
from prototorch.modules import LambdaLayer
from prototorch.modules.losses import NeuralGasEnergy
from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback from .callbacks import GNGCallback
from .extras import ConnectionTopology from .extras import ConnectionTopology

View File

@ -7,6 +7,8 @@ import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
@ -73,23 +75,7 @@ class Vis2DAbstract(pl.Callback):
ax.axis("off") ax.axis("off")
return ax return ax
def get_mesh_input(self, x): def plot_data(self, ax, x, y):
x_shift = self.border * np.ptp(x[:, 0])
y_shift = self.border * np.ptp(x[:, 1])
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
np.linspace(y_min, y_max, self.resolution))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy
def perform_pca_2D(self, data):
(_, eigVal, eigVec) = torch.pca_lowrank(data, q=2)
return data @ eigVec
def plot_data(self, ax, x, y, pca=False):
if pca:
x = self.perform_pca_2D(x)
ax.scatter( ax.scatter(
x[:, 0], x[:, 0],
x[:, 1], x[:, 1],
@ -100,9 +86,7 @@ class Vis2DAbstract(pl.Callback):
s=30, s=30,
) )
def plot_protos(self, ax, protos, plabels, pca=False): def plot_protos(self, ax, protos, plabels):
if pca:
protos = self.perform_pca_2D(protos)
ax.scatter( ax.scatter(
protos[:, 0], protos[:, 0],
protos[:, 1], protos[:, 1],
@ -146,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input) y_pred = pl_module.predict(mesh_input)
@ -181,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
if self.show_protos: if self.show_protos:
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
else: else:
mesh_input, xx, yy = self.get_mesh_input(x_train) mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.Tensor(mesh_input).type_as(_components) mesh_input = torch.Tensor(mesh_input).type_as(_components)
y_pred = pl_module.predict_latent(mesh_input, y_pred = pl_module.predict_latent(mesh_input,
@ -194,50 +178,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
with torch.no_grad():
x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
x_train = x_train.cpu().detach()
if self.map_protos:
with torch.no_grad():
protos = pl_module.backbone(torch.Tensor(protos).to(device))
protos = protos.cpu().detach()
ax = self.setup_ax()
if x_train.shape[1] > 2:
self.plot_data(ax, x_train, y_train, pca=True)
else:
self.plot_data(ax, x_train, y_train, pca=False)
if self.show_protos:
if protos.shape[1] > 2:
self.plot_protos(ax, protos, plabels, pca=True)
else:
self.plot_protos(ax, protos, plabels, pca=False)
### something to work on: meshgrid with pca
# x = np.vstack((x_train, protos))
# mesh_input, xx, yy = self.get_mesh_input(x)
#else:
# mesh_input, xx, yy = self.get_mesh_input(x_train)
#_components = pl_module.proto_layer._components
#mesh_input = torch.Tensor(mesh_input).type_as(_components)
#y_pred = pl_module.predict_latent(mesh_input,
# map_protos=self.map_protos)
#y_pred = y_pred.cpu().reshape(xx.shape)
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.precheck(trainer):
@ -250,8 +190,8 @@ class VisCBC2D(Vis2DAbstract):
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.component_layer._components _components = pl_module.components_layer._components
y_pred = pl_module.predict( y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components)) torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape) y_pred = y_pred.cpu().reshape(xx.shape)

8
setup.cfg Normal file
View File

@ -0,0 +1,8 @@
[isort]
profile = hug
src_paths = isort, test
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true