diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 04da7c4..0b774e0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -4,7 +4,10 @@ commit = True tag = True parse = (?P\d+)\.(?P\d+)\.(?P\d+) serialize = {major}.{minor}.{patch} +message = bump: {current_version} → {new_version} [bumpversion:file:setup.py] [bumpversion:file:./prototorch/models/__init__.py] + +[bumpversion:file:./docs/source/conf.py] diff --git a/.gitignore b/.gitignore index 4defab3..15f89b6 100644 --- a/.gitignore +++ b/.gitignore @@ -128,14 +128,19 @@ dmypy.json # Pyre type checker .pyre/ -# Datasets -datasets/ - -# PyTorch-Lightning -lightning_logs/ - .vscode/ +# Vim +*~ +*.swp +*.swo + # Pytorch Models or Weights # If necessary make exceptions for single pretrained models *.pt + +# Artifacts created by ProtoTorch Models +datasets/ +lightning_logs/ +examples/_*.py +examples/_*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf37edc..001227c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,54 +1,53 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks -repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-yaml - - id: check-added-large-files - - id: check-ast - - id: check-case-conflict +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: check-ast + - id: check-case-conflict - repo: https://github.com/myint/autoflake rev: v1.4 hooks: - - id: autoflake + - id: autoflake - repo: http://github.com/PyCQA/isort rev: 5.8.0 hooks: - - id: isort + - id: isort -- repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.902' - hooks: - - id: mypy - files: prototorch - additional_dependencies: [types-pkg_resources] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.902 + hooks: + - id: mypy + files: prototorch + additional_dependencies: [types-pkg_resources] -- repo: https://github.com/pre-commit/mirrors-yapf - rev: 'v0.31.0' # Use the sha / tag you want to point at - hooks: - - id: yapf +- repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.31.0 + hooks: + - id: yapf -- repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.9.0 # Use the ref you want to point at - hooks: - - id: python-use-type-annotations - - id: python-no-log-warn - - id: python-check-blanket-noqa +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.9.0 + hooks: + - id: python-use-type-annotations + - id: python-no-log-warn + - id: python-check-blanket-noqa +- repo: https://github.com/asottile/pyupgrade + rev: v2.19.4 + hooks: + - id: pyupgrade -- repo: https://github.com/asottile/pyupgrade - rev: v2.19.4 - hooks: - - id: pyupgrade - -- repo: https://github.com/jorisroovers/gitlint - rev: "v0.15.1" - hooks: - - id: gitlint - args: [--contrib=CT1, --ignore=B6, --msg-filename] +- repo: https://github.com/si-cim/gitlint + rev: v0.15.2-unofficial + hooks: + - id: gitlint + args: [--contrib=CT1, --ignore=B6, --msg-filename] diff --git a/README.md b/README.md index d5bd355..545e523 100644 --- a/README.md +++ b/README.md @@ -20,23 +20,6 @@ pip install prototorch_models of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then be available for use in your Python environment as `prototorch.models`. -## Contribution - -This repository contains definition for [git hooks](https://githooks.com). -[Pre-commit](https://pre-commit.com) is automatically installed as development -dependency with prototorch or you can install it manually with `pip install -pre-commit`. - -Please install the hooks by running: -```bash -pre-commit install -pre-commit install --hook-type commit-msg -``` -before creating the first commit. - -The commit will fail if the commit message does not follow the specification -provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification). - ## Available models ### LVQ Family @@ -103,6 +86,23 @@ To assist in the development process, you may also find it useful to install please avoid installing Tensorflow in this environment. It is known to cause problems with PyTorch-Lightning.** +## Contribution + +This repository contains definition for [git hooks](https://githooks.com). +[Pre-commit](https://pre-commit.com) is automatically installed as development +dependency with prototorch or you can install it manually with `pip install +pre-commit`. + +Please install the hooks by running: +```bash +pre-commit install +pre-commit install --hook-type commit-msg +``` +before creating the first commit. + +The commit will fail if the commit message does not follow the specification +provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification). + ## FAQ ### How do I update the plugin? diff --git a/docs/source/conf.py b/docs/source/conf.py index 63cbd09..5db5266 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,7 +23,7 @@ author = "Jensun Ravichandran" # The full version, including alpha/beta/rc tags # -release = "0.4.4" +release = "0.1.8" # -- General configuration --------------------------------------------------- diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index e204e81..f0561af 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -24,14 +23,18 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( - distribution=[2, 2, 2], - proto_lr=0.1, + distribution=[1, 0, 3], + margin=0.1, + proto_lr=0.01, + bb_lr=0.01, ) # Initialize the model model = pt.models.CBC( hparams, - prototype_initializer=pt.components.SSI(train_ds, noise=0.01), + components_initializer=pt.initializers.SSCI(train_ds, noise=0.01), + reasonings_iniitializer=pt.initializers. + PurePositiveReasoningsInitializer(), ) # Callbacks diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index 9104393..454b66e 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -37,7 +36,7 @@ if __name__ == "__main__": # Initialize the model model = pt.models.CELVQ( hparams, - prototype_initializer=pt.components.Ones(2, scale=3), + prototypes_initializer=pt.initializers.FVCI(2, 3.0), ) # Compute intermediate input and output sizes diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index 328911e..f9556ae 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -2,12 +2,11 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch from torch.optim.lr_scheduler import ExponentialLR -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -24,7 +23,7 @@ if __name__ == "__main__": hparams = dict( distribution={ "num_classes": 3, - "prototypes_per_class": 4 + "per_class": 4 }, lr=0.01, ) @@ -33,7 +32,7 @@ if __name__ == "__main__": model = pt.models.GLVQ( hparams, optimizer=torch.optim.Adam, - prototype_initializer=pt.components.SMI(train_ds), + prototypes_initializer=pt.initializers.SMCI(train_ds), lr_scheduler=ExponentialLR, lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), ) diff --git a/examples/glvq_spiral.py b/examples/glvq_spiral.py index f6092a1..386c01a 100644 --- a/examples/glvq_spiral.py +++ b/examples/glvq_spiral.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -26,7 +25,6 @@ if __name__ == "__main__": distribution=(num_classes, prototypes_per_class), transfer_function="swish_beta", transfer_beta=10.0, - # lr=0.1, proto_lr=0.1, bb_lr=0.1, input_dim=2, @@ -37,7 +35,7 @@ if __name__ == "__main__": model = pt.models.GMLVQ( hparams, optimizer=torch.optim.Adam, - prototype_initializer=pt.components.SSI(train_ds, noise=1e-2), + prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2), ) # Callbacks @@ -47,12 +45,12 @@ if __name__ == "__main__": block=False, ) pruning = pt.models.PruneLoserPrototypes( - threshold=0.02, + threshold=0.01, idle_epochs=10, prune_quota_per_epoch=5, - frequency=2, + frequency=5, replace=True, - initializer=pt.components.SSI(train_ds, noise=1e-2), + prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1), verbose=True, ) es = pl.callbacks.EarlyStopping( @@ -68,7 +66,7 @@ if __name__ == "__main__": args, callbacks=[ vis, - # es, + # es, # FIXME pruning, ], terminate_on_nan=True, diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py deleted file mode 100644 index 2bfa4c5..0000000 --- a/examples/gmlvq_iris.py +++ /dev/null @@ -1,59 +0,0 @@ -"""GLVQ example using the Iris dataset.""" - -import argparse - -import prototorch as pt -import pytorch_lightning as pl -import torch -from torch.optim.lr_scheduler import ExponentialLR - -if __name__ == "__main__": - # Command-line arguments - parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() - - # Dataset - train_ds = pt.datasets.Iris() - - # Dataloaders - train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) - - # Hyperparameters - hparams = dict( - input_dim=4, - latent_dim=3, - distribution={ - "num_classes": 3, - "prototypes_per_class": 2 - }, - proto_lr=0.0005, - bb_lr=0.0005, - ) - - # Initialize the model - model = pt.models.GMLVQ( - hparams, - optimizer=torch.optim.Adam, - prototype_initializer=pt.components.SSI(train_ds), - lr_scheduler=ExponentialLR, - lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), - omega_initializer=pt.components.PCA(train_ds.data) - ) - - # Compute intermediate input and output sizes - #model.example_input_array = torch.zeros(4, 2) - - # Callbacks - vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1) - - # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=[vis], - weights_summary="full", - accelerator="ddp", - ) - - # Training loop - trainer.fit(model, train_loader) diff --git a/examples/gng_iris.py b/examples/gng_iris.py index 04d4107..7f1275d 100644 --- a/examples/gng_iris.py +++ b/examples/gng_iris.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() @@ -30,7 +29,7 @@ if __name__ == "__main__": # Initialize the model model = pt.models.GrowingNeuralGas( hparams, - prototype_initializer=pt.components.Zeros(2), + prototypes_initializer=pt.initializers.ZCI(2), ) # Compute intermediate input and output sizes diff --git a/examples/ksom_colors.py b/examples/ksom_colors.py index f89a144..eee4a04 100644 --- a/examples/ksom_colors.py +++ b/examples/ksom_colors.py @@ -2,25 +2,11 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch from matplotlib import pyplot as plt - -import prototorch as pt - - -def hex_to_rgb(hex_values): - for v in hex_values: - v = v.lstrip('#') - lv = len(v) - c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)] - yield c - - -def rgb_to_hex(rgb_values): - for v in rgb_values: - c = "%02x%02x%02x" % tuple(v) - yield c +from prototorch.utils.colors import hex_to_rgb class Vis2DColorSOM(pl.Callback): @@ -93,7 +79,7 @@ if __name__ == "__main__": # Initialize the model model = pt.models.KohonenSOM( hparams, - prototype_initializer=pt.components.Random(3), + prototypes_initializer=pt.initializers.RNCI(3), ) # Compute intermediate input and output sizes diff --git a/examples/lgmlvq_moons.py b/examples/lgmlvq_moons.py index 1782045..cbdc9b4 100644 --- a/examples/lgmlvq_moons.py +++ b/examples/lgmlvq_moons.py @@ -2,23 +2,22 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() - # Dataset - train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) - # Reproducibility pl.utilities.seed.seed_everything(seed=2) + # Dataset + train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) + # Dataloaders train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256, @@ -32,8 +31,10 @@ if __name__ == "__main__": ) # Initialize the model - model = pt.models.LGMLVQ(hparams, - prototype_initializer=pt.components.SMI(train_ds)) + model = pt.models.LGMLVQ( + hparams, + prototypes_initializer=pt.initializers.SMCI(train_ds), + ) # Compute intermediate input and output sizes model.example_input_array = torch.zeros(4, 2) diff --git a/examples/liramlvq_tecator.py b/examples/liramlvq_tecator.py index e0a97f8..1d95546 100644 --- a/examples/liramlvq_tecator.py +++ b/examples/liramlvq_tecator.py @@ -3,11 +3,10 @@ import argparse import matplotlib.pyplot as plt +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - def plot_matrix(matrix): title = "Lambda matrix" @@ -40,20 +39,19 @@ if __name__ == "__main__": hparams = dict( distribution={ "num_classes": 2, - "prototypes_per_class": 1 + "per_class": 1, }, input_dim=100, latent_dim=2, - proto_lr=0.0001, - bb_lr=0.0001, + proto_lr=0.001, + bb_lr=0.001, ) # Initialize the model model = pt.models.SiameseGMLVQ( hparams, - # optimizer=torch.optim.SGD, optimizer=torch.optim.Adam, - prototype_initializer=pt.components.SMI(train_ds), + prototypes_initializer=pt.initializers.SMCI(train_ds), ) # Summary diff --git a/examples/lvqmln_iris.py b/examples/lvqmln_iris.py index c688788..79df874 100644 --- a/examples/lvqmln_iris.py +++ b/examples/lvqmln_iris.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): @@ -41,7 +40,7 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( - distribution=[1, 2, 2], + distribution=[3, 4, 5], proto_lr=0.001, bb_lr=0.001, ) @@ -52,7 +51,10 @@ if __name__ == "__main__": # Initialize the model model = pt.models.LVQMLN( hparams, - prototype_initializer=pt.components.SSI(train_ds, transform=backbone), + prototypes_initializer=pt.initializers.SSCI( + train_ds, + transform=backbone, + ), backbone=backbone, ) @@ -67,11 +69,21 @@ if __name__ == "__main__": resolution=500, axis_off=True, ) + pruning = pt.models.PruneLoserPrototypes( + threshold=0.01, + idle_epochs=20, + prune_quota_per_epoch=2, + frequency=10, + verbose=True, + ) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - callbacks=[vis], + callbacks=[ + vis, + pruning, + ], ) # Training loop diff --git a/examples/rslvq_iris.py b/examples/rslvq_iris.py index 735d219..c7d3961 100644 --- a/examples/rslvq_iris.py +++ b/examples/rslvq_iris.py @@ -2,11 +2,9 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -from torchvision.transforms import Lambda - -import prototorch as pt if __name__ == "__main__": # Command-line arguments @@ -28,19 +26,17 @@ if __name__ == "__main__": distribution=[2, 2, 3], proto_lr=0.05, lambd=0.1, + variance=1.0, input_dim=2, latent_dim=2, bb_lr=0.01, ) # Initialize the model - model = pt.models.probabilistic.PLVQ( + model = pt.models.RSLVQ( hparams, optimizer=torch.optim.Adam, - # prototype_initializer=pt.components.SMI(train_ds), - prototype_initializer=pt.components.SSI(train_ds, noise=0.2), - # prototype_initializer=pt.components.Zeros(2), - # prototype_initializer=pt.components.Ones(2, scale=2.0), + prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2), ) # Compute intermediate input and output sizes @@ -50,7 +46,7 @@ if __name__ == "__main__": print(model) # Callbacks - vis = pt.models.VisSiameseGLVQ2D(data=train_ds) + vis = pt.models.VisGLVQ2D(data=train_ds) # Setup trainer trainer = pl.Trainer.from_argparse_args( diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index cdd279d..9ca9d07 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -2,11 +2,10 @@ import argparse +import prototorch as pt import pytorch_lightning as pl import torch -import prototorch as pt - class Backbone(torch.nn.Module): def __init__(self, input_size=4, hidden_size=10, latent_size=2): @@ -52,7 +51,7 @@ if __name__ == "__main__": # Initialize the model model = pt.models.SiameseGLVQ( hparams, - prototype_initializer=pt.components.SMI(train_ds), + prototypes_initializer=pt.initializers.SMCI(train_ds), backbone=backbone, both_path_gradients=False, ) diff --git a/examples/warm_starting.py b/examples/warm_starting.py new file mode 100644 index 0000000..1a966f3 --- /dev/null +++ b/examples/warm_starting.py @@ -0,0 +1,84 @@ +"""Warm-starting GLVQ with prototypes from Growing Neural Gas.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch +from torch.optim.lr_scheduler import ExponentialLR + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Prepare the data + train_ds = pt.datasets.Iris(dims=[0, 2]) + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + + # Initialize the gng + gng = pt.models.GrowingNeuralGas( + hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1), + prototypes_initializer=pt.initializers.ZCI(2), + lr_scheduler=ExponentialLR, + lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), + ) + + # Callbacks + es = pl.callbacks.EarlyStopping( + monitor="loss", + min_delta=0.001, + patience=20, + mode="min", + verbose=False, + check_on_train_epoch_end=True, + ) + + # Setup trainer for GNG + trainer = pl.Trainer( + max_epochs=200, + callbacks=[es], + weights_summary=None, + ) + + # Training loop + trainer.fit(gng, train_loader) + + # Hyperparameters + hparams = dict( + distribution=[], + lr=0.01, + ) + + # Warm-start prototypes + knn = pt.models.KNN(dict(k=1), data=train_ds) + prototypes = gng.prototypes + plabels = knn.predict(prototypes) + + # Initialize the model + model = pt.models.GLVQ( + hparams, + optimizer=torch.optim.Adam, + prototypes_initializer=pt.initializers.LCI(prototypes), + labels_initializer=pt.initializers.LLI(plabels), + lr_scheduler=ExponentialLR, + lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), + ) + + # Compute intermediate input and output sizes + model.example_input_array = torch.zeros(4, 2) + + # Callbacks + vis = pt.models.VisGLVQ2D(data=train_ds) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[vis], + weights_summary="full", + accelerator="ddp", + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 05f0672..420a580 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version from .callbacks import PrototypeConvergence, PruneLoserPrototypes from .cbc import CBC, ImageCBC -from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, - ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) +from .glvq import ( + GLVQ, + GLVQ1, + GLVQ21, + GMLVQ, + GRLVQ, + LGMLVQ, + LVQMLN, + ImageGLVQ, + ImageGMLVQ, + SiameseGLVQ, + SiameseGMLVQ, +) from .knn import KNN from .lvq import LVQ1, LVQ21, MedianLVQ from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 89d322b..d55bd08 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -5,9 +5,13 @@ from typing import Final, final import pytorch_lightning as pl import torch import torchmetrics -from prototorch.components import Components, LabeledComponents -from prototorch.functions.distances import euclidean_distance -from prototorch.modules import WTAC, LambdaLayer + +from ..core.competitions import WTAC +from ..core.components import Components, LabeledComponents +from ..core.distances import euclidean_distance +from ..core.initializers import LabelsInitializer +from ..core.pooling import stratified_min_pooling +from ..nn.wrappers import LambdaLayer class ProtoTorchMixin(object): @@ -85,13 +89,11 @@ class UnsupervisedPrototypeModel(PrototypeModel): super().__init__(hparams, **kwargs) # Layers - prototype_initializer = kwargs.get("prototype_initializer", None) - initialized_prototypes = kwargs.get("initialized_prototypes", None) - if prototype_initializer is not None or initialized_prototypes is not None: + prototypes_initializer = kwargs.get("prototypes_initializer", None) + if prototypes_initializer is not None: self.proto_layer = Components( self.hparams.num_prototypes, - initializer=prototype_initializer, - initialized_components=initialized_prototypes, + initializer=prototypes_initializer, ) def compute_distances(self, x): @@ -109,23 +111,24 @@ class SupervisedPrototypeModel(PrototypeModel): super().__init__(hparams, **kwargs) # Layers - prototype_initializer = kwargs.get("prototype_initializer", None) - initialized_prototypes = kwargs.get("initialized_prototypes", None) - if prototype_initializer is not None or initialized_prototypes is not None: + prototypes_initializer = kwargs.get("prototypes_initializer", None) + labels_initializer = kwargs.get("labels_initializer", + LabelsInitializer()) + if prototypes_initializer is not None: self.proto_layer = LabeledComponents( distribution=self.hparams.distribution, - initializer=prototype_initializer, - initialized_components=initialized_prototypes, + components_initializer=prototypes_initializer, + labels_initializer=labels_initializer, ) self.competition_layer = WTAC() @property def prototype_labels(self): - return self.proto_layer.component_labels.detach().cpu() + return self.proto_layer.labels.detach().cpu() @property def num_classes(self): - return len(self.proto_layer.distribution) + return self.proto_layer.num_classes def compute_distances(self, x): protos, _ = self.proto_layer() @@ -134,15 +137,14 @@ class SupervisedPrototypeModel(PrototypeModel): def forward(self, x): distances = self.compute_distances(x) - y_pred = self.predict_from_distances(distances) - # TODO - y_pred = torch.eye(self.num_classes, device=self.device)[ - y_pred.long()] # depends on labels {0,...,num_classes} + plabels = self.proto_layer.labels + winning = stratified_min_pooling(distances, plabels) + y_pred = torch.nn.functional.softmin(winning) return y_pred def predict_from_distances(self, distances): with torch.no_grad(): - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels y_pred = self.competition_layer(distances, plabels) return y_pred diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index d088c7d..62b628a 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -4,8 +4,9 @@ import logging import pytorch_lightning as pl import torch -from prototorch.components import Components +from ..core.components import Components +from ..core.initializers import LiteralCompInitializer from .extras import ConnectionTopology @@ -16,7 +17,7 @@ class PruneLoserPrototypes(pl.Callback): prune_quota_per_epoch=-1, frequency=1, replace=False, - initializer=None, + prototypes_initializer=None, verbose=False): self.threshold = threshold # minimum win ratio self.idle_epochs = idle_epochs # epochs to wait before pruning @@ -24,7 +25,7 @@ class PruneLoserPrototypes(pl.Callback): self.frequency = frequency self.replace = replace self.verbose = verbose - self.initializer = initializer + self.prototypes_initializer = prototypes_initializer def on_epoch_end(self, trainer, pl_module): if (trainer.current_epoch + 1) < self.idle_epochs: @@ -55,8 +56,9 @@ class PruneLoserPrototypes(pl.Callback): if self.verbose: print(f"Re-adding pruned prototypes...") print(f"{distribution=}") - pl_module.add_prototypes(distribution=distribution, - initializer=self.initializer) + pl_module.add_prototypes( + distribution=distribution, + components_initializer=self.prototypes_initializer) new_num_protos = pl_module.num_prototypes if self.verbose: print(f"`num_prototypes` changed from {cur_num_protos} " @@ -116,7 +118,8 @@ class GNGCallback(pl.Callback): # Add component pl_module.proto_layer.add_components( - initialized_components=new_component.unsqueeze(0)) + None, + initializer=LiteralCompInitializer(new_component.unsqueeze(0))) # Adjust Topology topology.add_prototype() diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 2a12b12..430ae6c 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -1,49 +1,54 @@ import torch import torchmetrics +from ..core.competitions import CBCC +from ..core.components import ReasoningComponents +from ..core.initializers import RandomReasoningsInitializer +from ..core.losses import MarginLoss +from ..core.similarities import euclidean_similarity +from ..nn.wrappers import LambdaLayer from .abstract import ImagePrototypesMixin -from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer, - euclidean_similarity, rescaled_cosine_similarity, - shift_activation) from .glvq import SiameseGLVQ class CBC(SiameseGLVQ): """Classification-By-Components.""" - def __init__(self, hparams, margin=0.1, **kwargs): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) - self.margin = margin - self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity) - num_components = self.components.shape[0] - self.reasoning_layer = ReasoningLayer(num_components=num_components, - num_classes=self.num_classes) - self.component_layer = self.proto_layer - @property - def components(self): - return self.prototypes + similarity_fn = kwargs.get("similarity_fn", euclidean_similarity) + components_initializer = kwargs.get("components_initializer", None) + reasonings_initializer = kwargs.get("reasonings_initializer", + RandomReasoningsInitializer()) + self.components_layer = ReasoningComponents( + self.hparams.distribution, + components_initializer=components_initializer, + reasonings_initializer=reasonings_initializer, + ) + self.similarity_layer = LambdaLayer(similarity_fn) + self.competition_layer = CBCC() - @property - def reasonings(self): - return self.reasoning_layer.reasonings.cpu() + # Namespace hook + self.proto_layer = self.components_layer + + self.loss = MarginLoss(self.hparams.margin) def forward(self, x): - components, _ = self.component_layer() + components, reasonings = self.components_layer() latent_x = self.backbone(x) self.backbone.requires_grad_(self.both_path_gradients) latent_components = self.backbone(components) self.backbone.requires_grad_(True) - detections = self.similarity_fn(latent_x, latent_components) - probs = self.reasoning_layer(detections) + detections = self.similarity_layer(latent_x, latent_components) + probs = self.competition_layer(detections, reasonings) return probs def shared_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch - # x = x.view(x.size(0), -1) y_pred = self(x) - num_classes = self.reasoning_layer.num_classes + num_classes = self.num_classes y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes) - loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0) + loss = self.loss(y_pred, y_true).mean(dim=0) return y_pred, loss def training_step(self, batch, batch_idx, optimizer_idx=None): @@ -70,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC): """CBC model that constrains the components to the range [0, 1] by clamping after updates. """ - def __init__(self, hparams, **kwargs): - super().__init__(hparams, **kwargs) - # Namespace hook - self.proto_layer = self.component_layer diff --git a/prototorch/models/extras.py b/prototorch/models/extras.py index 60bd158..644d0f1 100644 --- a/prototorch/models/extras.py +++ b/prototorch/models/extras.py @@ -5,23 +5,32 @@ Modules not yet available in prototorch go here temporarily. """ import torch -from prototorch.functions.distances import euclidean_distance -from prototorch.functions.similarities import cosine_similarity + +from ..core.similarities import gaussian -def rescaled_cosine_similarity(x, y): - """Cosine Similarity rescaled to [0, 1].""" - similarities = cosine_similarity(x, y) - return (similarities + 1.0) / 2.0 +def rank_scaled_gaussian(distances, lambd): + order = torch.argsort(distances, dim=1) + ranks = torch.argsort(order, dim=1) + return torch.exp(-torch.exp(-ranks / lambd) * distances) -def shift_activation(x): - return (x + 1.0) / 2.0 +class GaussianPrior(torch.nn.Module): + def __init__(self, variance): + super().__init__() + self.variance = variance + + def forward(self, distances): + return gaussian(distances, self.variance) -def euclidean_similarity(x, y, variance=1.0): - d = euclidean_distance(x, y) - return torch.exp(-(d * d) / (2 * variance)) +class RankScaledGaussianPrior(torch.nn.Module): + def __init__(self, lambd): + super().__init__() + self.lambd = lambd + + def forward(self, distances): + return rank_scaled_gaussian(distances, self.lambd) class ConnectionTopology(torch.nn.Module): @@ -79,64 +88,3 @@ class ConnectionTopology(torch.nn.Module): def extra_repr(self): return f"(agelimit): ({self.agelimit})" - - -class CosineSimilarity(torch.nn.Module): - def __init__(self, activation=shift_activation): - super().__init__() - self.activation = activation - - def forward(self, x, y): - epsilon = torch.finfo(x.dtype).eps - normed_x = (x / x.pow(2).sum(dim=tuple(range( - 1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten( - start_dim=1) - normed_y = (y / y.pow(2).sum(dim=tuple(range( - 1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten( - start_dim=1) - # normed_x = (x / torch.linalg.norm(x, dim=1)) - diss = torch.inner(normed_x, normed_y) - return self.activation(diss) - - -class MarginLoss(torch.nn.modules.loss._Loss): - def __init__(self, - margin=0.3, - size_average=None, - reduce=None, - reduction="mean"): - super().__init__(size_average, reduce, reduction) - self.margin = margin - - def forward(self, input_, target): - dp = torch.sum(target * input_, dim=-1) - dm = torch.max(input_ - target, dim=-1).values - return torch.nn.functional.relu(dm - dp + self.margin) - - -class ReasoningLayer(torch.nn.Module): - def __init__(self, num_components, num_classes, num_replicas=1): - super().__init__() - self.num_replicas = num_replicas - self.num_classes = num_classes - probabilities_init = torch.zeros(2, 1, num_components, - self.num_classes) - probabilities_init.uniform_(0.4, 0.6) - # TODO Use `self.register_parameter("param", Paramater(param))` instead - self.reasoning_probabilities = torch.nn.Parameter(probabilities_init) - - @property - def reasonings(self): - pk = self.reasoning_probabilities[0] - nk = (1 - pk) * self.reasoning_probabilities[1] - ik = 1 - pk - nk - img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2) - return img.unsqueeze(1) - - def forward(self, detections): - pk = self.reasoning_probabilities[0].clamp(0, 1) - nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1) - numerator = (detections @ (pk - nk)) + nk.sum(1) - probs = numerator / (pk + nk).sum(1) - probs = probs.squeeze(0) - return probs diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index ec80e1e..cc8ed57 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -1,16 +1,14 @@ """Models based on the GLVQ framework.""" import torch -from prototorch.functions.activations import get_activation -from prototorch.functions.competitions import wtac -from prototorch.functions.distances import (lomega_distance, omega_distance, - squared_euclidean_distance) -from prototorch.functions.helper import get_flat -from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss -from prototorch.components import LinearMapping -from prototorch.modules import LambdaLayer, LossLayer from torch.nn.parameter import Parameter +from ..core.competitions import wtac +from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance +from ..core.initializers import EyeTransformInitializer +from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss +from ..nn.activations import get_activation +from ..nn.wrappers import LambdaLayer, LossLayer from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel @@ -30,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel): # Loss self.loss = LossLayer(glvq_loss) - # Prototype metrics - self.initialize_prototype_win_ratios() - def initialize_prototype_win_ratios(self): self.register_buffer( "prototype_win_ratios", @@ -59,7 +54,7 @@ class GLVQ(SupervisedPrototypeModel): def shared_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch out = self.compute_distances(x) - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels mu = self.loss(out, y, prototype_labels=plabels) batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta) loss = batch_loss.sum(dim=0) @@ -135,7 +130,7 @@ class SiameseGLVQ(GLVQ): def compute_distances(self, x): protos, _ = self.proto_layer() - x, protos = get_flat(x, protos) + x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] latent_x = self.backbone(x) self.backbone.requires_grad_(self.both_path_gradients) latent_protos = self.backbone(protos) @@ -240,18 +235,14 @@ class GMLVQ(GLVQ): super().__init__(hparams, distance_fn=distance_fn, **kwargs) # Additional parameters - omega_initializer = kwargs.get("omega_initializer", None) - initialized_omega = kwargs.get("initialized_omega", None) - if omega_initializer is not None or initialized_omega is not None: - self.omega_layer = LinearMapping( - mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim), - initializer=omega_initializer, - initialized_linearmapping=initialized_omega, - ) + omega_initializer = kwargs.get("omega_initializer", + EyeTransformInitializer()) + omega = omega_initializer.generate(self.hparams.input_dim, + self.hparams.latent_dim) + self.register_parameter("_omega", Parameter(omega)) + self.backbone = LambdaLayer(lambda x: x @ self._omega, + name="omega matrix") - self.register_parameter("_omega", Parameter(self.omega_layer.mapping)) - self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix") - @property def omega_matrix(self): return self._omega.detach().cpu() @@ -264,24 +255,6 @@ class GMLVQ(GLVQ): def extra_repr(self): return f"(omega): (shape: {tuple(self._omega.shape)})" - def predict_latent(self, x, map_protos=True): - """Predict `x` assuming it is already embedded in the latent space. - - Only the prototypes are embedded in the latent space using the - backbone. - - """ - self.eval() - with torch.no_grad(): - protos, plabels = self.proto_layer() - if map_protos: - protos = self.backbone(protos) - d = squared_euclidean_distance(x, protos) - y_pred = wtac(d, plabels) - return y_pred - - - class LGMLVQ(GMLVQ): """Localized and Generalized Matrix Learning Vector Quantization.""" diff --git a/prototorch/models/knn.py b/prototorch/models/knn.py index c23a276..0886550 100644 --- a/prototorch/models/knn.py +++ b/prototorch/models/knn.py @@ -2,9 +2,10 @@ import warnings -from prototorch.components import LabeledComponents -from prototorch.modules import KNNC - +from ..core.competitions import KNNC +from ..core.components import LabeledComponents +from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer +from ..utils.utils import parse_data_arg from .abstract import SupervisedPrototypeModel @@ -19,9 +20,13 @@ class KNN(SupervisedPrototypeModel): data = kwargs.get("data", None) if data is None: raise ValueError("KNN requires data, but was not provided!") + data, targets = parse_data_arg(data) # Layers - self.proto_layer = LabeledComponents(initialized_components=data) + self.proto_layer = LabeledComponents( + distribution=[], + components_initializer=LiteralCompInitializer(data), + labels_initializer=LiteralLabelsInitializer(targets)) self.competition_layer = KNNC(k=self.hparams.k) def training_step(self, train_batch, batch_idx, optimizer_idx=None): diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py index d3f60d4..61c946f 100644 --- a/prototorch/models/lvq.py +++ b/prototorch/models/lvq.py @@ -1,7 +1,6 @@ """LVQ models that are optimized using non-gradient methods.""" -from prototorch.functions.losses import _get_dp_dm - +from ..core.losses import _get_dp_dm from .abstract import NonGradientMixin from .glvq import GLVQ @@ -10,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ): """Learning Vector Quantization 1.""" def training_step(self, train_batch, batch_idx, optimizer_idx=None): protos = self.proto_layer.components - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels x, y = train_batch dis = self.compute_distances(x) @@ -29,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ): self.proto_layer.load_state_dict({"_components": updated_protos}, strict=False) + print(f"{dis=}") + print(f"{y=}") # Logging self.log_acc(dis, y, tag="train_acc") @@ -39,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ): """Learning Vector Quantization 2.1.""" def training_step(self, train_batch, batch_idx, optimizer_idx=None): protos = self.proto_layer.components - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels x, y = train_batch dis = self.compute_distances(x) diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py index b1771c5..d7276f3 100644 --- a/prototorch/models/probabilistic.py +++ b/prototorch/models/probabilistic.py @@ -1,13 +1,11 @@ """Probabilistic GLVQ methods""" import torch -from prototorch.functions.losses import nllr_loss, rslvq_loss -from prototorch.functions.pooling import (stratified_min_pooling, - stratified_sum_pooling) -from prototorch.functions.transforms import (GaussianPrior, - RankScaledGaussianPrior) -from prototorch.modules import LambdaLayer, LossLayer +from ..core.losses import nllr_loss, rslvq_loss +from ..core.pooling import stratified_min_pooling, stratified_sum_pooling +from ..nn.wrappers import LambdaLayer, LossLayer +from .extras import GaussianPrior, RankScaledGaussianPrior from .glvq import GLVQ, SiameseGMLVQ @@ -22,7 +20,7 @@ class CELVQ(GLVQ): def shared_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch out = self.compute_distances(x) # [None, num_protos] - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels winning = stratified_min_pooling(out, plabels) # [None, num_classes] probs = -1.0 * winning batch_loss = self.loss(probs, y.long()) @@ -56,7 +54,7 @@ class ProbabilisticLVQ(GLVQ): def training_step(self, batch, batch_idx, optimizer_idx=None): x, y = batch out = self.forward(x) - plabels = self.proto_layer.component_labels + plabels = self.proto_layer.labels batch_loss = self.loss(out, y, plabels) loss = batch_loss.sum(dim=0) return loss @@ -89,11 +87,10 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): self.hparams.lambd) self.loss = torch.nn.KLDivLoss() - def training_step(self, batch, batch_idx, optimizer_idx=None): - x, y = batch - out = self.forward(x) - y_dist = torch.nn.functional.one_hot( - y.long(), num_classes=self.num_classes).float() - batch_loss = self.loss(out, y_dist) - loss = batch_loss.sum(dim=0) - return loss + # FIXME + # def training_step(self, batch, batch_idx, optimizer_idx=None): + # x, y = batch + # y_pred = self(x) + # batch_loss = self.loss(y_pred, y) + # loss = batch_loss.sum(dim=0) + # return loss diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index d171115..ba2c80d 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -2,11 +2,11 @@ import numpy as np import torch -from prototorch.functions.competitions import wtac -from prototorch.functions.distances import squared_euclidean_distance -from prototorch.modules import LambdaLayer -from prototorch.modules.losses import NeuralGasEnergy +from ..core.competitions import wtac +from ..core.distances import squared_euclidean_distance +from ..core.losses import NeuralGasEnergy +from ..nn.wrappers import LambdaLayer from .abstract import NonGradientMixin, UnsupervisedPrototypeModel from .callbacks import GNGCallback from .extras import ConnectionTopology diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 68e53d4..4f6b696 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -7,6 +7,8 @@ import torchvision from matplotlib import pyplot as plt from torch.utils.data import DataLoader, Dataset +from ..utils.utils import mesh2d + class Vis2DAbstract(pl.Callback): def __init__(self, @@ -73,23 +75,7 @@ class Vis2DAbstract(pl.Callback): ax.axis("off") return ax - def get_mesh_input(self, x): - x_shift = self.border * np.ptp(x[:, 0]) - y_shift = self.border * np.ptp(x[:, 1]) - x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift - y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift - xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution), - np.linspace(y_min, y_max, self.resolution)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - return mesh_input, xx, yy - - def perform_pca_2D(self, data): - (_, eigVal, eigVec) = torch.pca_lowrank(data, q=2) - return data @ eigVec - - def plot_data(self, ax, x, y, pca=False): - if pca: - x = self.perform_pca_2D(x) + def plot_data(self, ax, x, y): ax.scatter( x[:, 0], x[:, 1], @@ -100,9 +86,7 @@ class Vis2DAbstract(pl.Callback): s=30, ) - def plot_protos(self, ax, protos, plabels, pca=False): - if pca: - protos = self.perform_pca_2D(protos) + def plot_protos(self, ax, protos, plabels): ax.scatter( protos[:, 0], protos[:, 1], @@ -146,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract): self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) _components = pl_module.proto_layer._components mesh_input = torch.from_numpy(mesh_input).type_as(_components) y_pred = pl_module.predict(mesh_input) @@ -181,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract): if self.show_protos: self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) else: - mesh_input, xx, yy = self.get_mesh_input(x_train) + mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution) _components = pl_module.proto_layer._components mesh_input = torch.Tensor(mesh_input).type_as(_components) y_pred = pl_module.predict_latent(mesh_input, @@ -194,50 +178,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract): self.log_and_display(trainer, pl_module) -class VisGMLVQ2D(Vis2DAbstract): - def __init__(self, *args, map_protos=True, **kwargs): - super().__init__(*args, **kwargs) - self.map_protos = map_protos - - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - - protos = pl_module.prototypes - plabels = pl_module.prototype_labels - x_train, y_train = self.x_train, self.y_train - device = pl_module.device - with torch.no_grad(): - x_train = pl_module.backbone(torch.Tensor(x_train).to(device)) - x_train = x_train.cpu().detach() - if self.map_protos: - with torch.no_grad(): - protos = pl_module.backbone(torch.Tensor(protos).to(device)) - protos = protos.cpu().detach() - ax = self.setup_ax() - if x_train.shape[1] > 2: - self.plot_data(ax, x_train, y_train, pca=True) - else: - self.plot_data(ax, x_train, y_train, pca=False) - if self.show_protos: - if protos.shape[1] > 2: - self.plot_protos(ax, protos, plabels, pca=True) - else: - self.plot_protos(ax, protos, plabels, pca=False) - ### something to work on: meshgrid with pca - # x = np.vstack((x_train, protos)) - # mesh_input, xx, yy = self.get_mesh_input(x) - #else: - # mesh_input, xx, yy = self.get_mesh_input(x_train) - #_components = pl_module.proto_layer._components - #mesh_input = torch.Tensor(mesh_input).type_as(_components) - #y_pred = pl_module.predict_latent(mesh_input, - # map_protos=self.map_protos) - #y_pred = y_pred.cpu().reshape(xx.shape) - #ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - self.log_and_display(trainer, pl_module) - - class VisCBC2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): if not self.precheck(trainer): @@ -250,8 +190,8 @@ class VisCBC2D(Vis2DAbstract): self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, "w") x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) - _components = pl_module.component_layer._components + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) + _components = pl_module.components_layer._components y_pred = pl_module.predict( torch.Tensor(mesh_input).type_as(_components)) y_pred = y_pred.cpu().reshape(xx.shape) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..24eeb0b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,8 @@ +[isort] +profile = hug +src_paths = isort, test + +[yapf] +based_on_style = pep8 +spaces_before_comment = 2 +split_before_logical_operator = true