diff --git a/.codacy.yml b/.codacy.yml deleted file mode 100644 index 2e7468d..0000000 --- a/.codacy.yml +++ /dev/null @@ -1,15 +0,0 @@ -# To validate the contents of your configuration file -# run the following command in the folder where the configuration file is located: -# codacy-analysis-cli validate-configuration --directory `pwd` -# To analyse, run: -# codacy-analysis-cli analyse --tool remark-lint --directory `pwd` ---- -engines: - pylintpython3: - exclude_paths: - - config/engines.yml - remark-lint: - exclude_paths: - - config/engines.yml -exclude_paths: - - 'tests/**' diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index cbf6b65..0000000 --- a/.codecov.yml +++ /dev/null @@ -1,2 +0,0 @@ -comment: - require_changes: yes diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml new file mode 100644 index 0000000..8357dd5 --- /dev/null +++ b/.github/workflows/examples.yml @@ -0,0 +1,25 @@ +# Thi workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: examples + +on: + push: + paths: + - 'examples/**.py' +jobs: + cpu: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[all] + - name: Run examples + run: | + ./tests/test_examples.sh examples/ diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml new file mode 100644 index 0000000..d94683b --- /dev/null +++ b/.github/workflows/pythonapp.yml @@ -0,0 +1,73 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: tests + +on: + push: + pull_request: + branches: [ master ] + +jobs: + style: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[all] + - uses: pre-commit/action@v2.0.3 + compatibility: + needs: style + strategy: + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9"] + os: [ubuntu-latest, windows-latest] + exclude: + - os: windows-latest + python-version: "3.7" + - os: windows-latest + python-version: "3.8" + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[all] + - name: Test with pytest + run: | + pytest + publish_pypi: + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + needs: compatibility + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.9 + uses: actions/setup-python@v2 + with: + python-version: "3.9" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[all] + pip install wheel + - name: Build package + run: python setup.py sdist bdist_wheel + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 001227c..94784d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.1.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -18,19 +18,19 @@ repos: - id: autoflake - repo: http://github.com/PyCQA/isort - rev: 5.8.0 + rev: 5.10.1 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.902 + rev: v0.931 hooks: - id: mypy files: prototorch additional_dependencies: [types-pkg_resources] - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.31.0 + rev: v0.32.0 hooks: - id: yapf @@ -42,7 +42,7 @@ repos: - id: python-check-blanket-noqa - repo: https://github.com/asottile/pyupgrade - rev: v2.19.4 + rev: v2.31.0 hooks: - id: pyupgrade diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index de0b7d6..0000000 --- a/.travis.yml +++ /dev/null @@ -1,44 +0,0 @@ -dist: bionic -sudo: false -language: python -python: - - 3.9 - - 3.8 - - 3.7 - - 3.6 -cache: - directories: - - "$HOME/.cache/pip" - - "./tests/artifacts" - - "$HOME/datasets" -install: -- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off -- pip install .[all] --progress-bar off -script: -- coverage run -m pytest -- ./tests/test_examples.sh examples/ -after_success: -- bash <(curl -s https://codecov.io/bash) - -# Publish on PyPI -jobs: - include: - - stage: build - python: 3.9 - script: echo "Starting Pypi build" - deploy: - provider: pypi - username: __token__ - distributions: "sdist bdist_wheel" - password: - secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg= - on: - tags: true - skip_existing: true - -# The password is encrypted with: -# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password` -# See https://docs.travis-ci.com/user/deployment/pypi and -# https://github.com/travis-ci/travis.rb#installation -# for more details -# Note: The encrypt command does not work well in ZSH. diff --git a/README.md b/README.md index 7cb728f..ec72c57 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # ProtoTorch Models -[![Build Status](https://api.travis-ci.com/si-cim/prototorch_models.svg?branch=main)](https://travis-ci.com/github/si-cim/prototorch_models) [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases) [![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/) [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE) diff --git a/examples/glvq_spiral.py b/examples/gmlvq_spiral.py similarity index 97% rename from examples/glvq_spiral.py rename to examples/gmlvq_spiral.py index 8a30212..68c7e3b 100644 --- a/examples/glvq_spiral.py +++ b/examples/gmlvq_spiral.py @@ -1,4 +1,4 @@ -"""GLVQ example using the spiral dataset.""" +"""GMLVQ example using the spiral dataset.""" import argparse diff --git a/examples/gtlvq_mnist.py b/examples/gtlvq_mnist.py new file mode 100644 index 0000000..481065a --- /dev/null +++ b/examples/gtlvq_mnist.py @@ -0,0 +1,104 @@ +"""GTLVQ example using the MNIST dataset.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch +from torchvision import transforms +from torchvision.datasets import MNIST + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = MNIST( + "~/datasets", + train=True, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + ]), + ) + test_ds = MNIST( + "~/datasets", + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + ]), + ) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + num_workers=0, + batch_size=256) + test_loader = torch.utils.data.DataLoader(test_ds, + num_workers=0, + batch_size=256) + + # Hyperparameters + num_classes = 10 + prototypes_per_class = 1 + hparams = dict( + input_dim=28 * 28, + latent_dim=28, + distribution=(num_classes, prototypes_per_class), + proto_lr=0.01, + bb_lr=0.01, + ) + + # Initialize the model + model = pt.models.ImageGTLVQ( + hparams, + optimizer=torch.optim.Adam, + prototypes_initializer=pt.initializers.SMCI(train_ds), + #Use one batch of data for subspace initiator. + omega_initializer=pt.initializers.PCALinearTransformInitializer( + next(iter(train_loader))[0].reshape(256, 28 * 28))) + + # Callbacks + vis = pt.models.VisImgComp( + data=train_ds, + num_columns=10, + show=False, + tensorboard=True, + random_data=100, + add_embedding=True, + embedding_data=200, + flatten_data=False, + ) + pruning = pt.models.PruneLoserPrototypes( + threshold=0.01, + idle_epochs=1, + prune_quota_per_epoch=10, + frequency=1, + verbose=True, + ) + es = pl.callbacks.EarlyStopping( + monitor="train_loss", + min_delta=0.001, + patience=15, + mode="min", + check_on_train_epoch_end=True, + ) + + # Setup trainer + # using GPUs here is strongly recommended! + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[ + vis, + pruning, + # es, + ], + terminate_on_nan=True, + weights_summary=None, + accelerator="ddp", + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/gtlvq_moons.py b/examples/gtlvq_moons.py new file mode 100644 index 0000000..79ff32f --- /dev/null +++ b/examples/gtlvq_moons.py @@ -0,0 +1,63 @@ +"""Localized-GTLVQ example using the Moons dataset.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + + # Dataset + train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, + batch_size=256, + shuffle=True) + + # Hyperparameters + # Latent_dim should be lower than input dim. + hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1) + + # Initialize the model + model = pt.models.GTLVQ( + hparams, prototypes_initializer=pt.initializers.SMCI(train_ds)) + + # Compute intermediate input and output sizes + model.example_input_array = torch.zeros(4, 2) + + # Summary + print(model) + + # Callbacks + vis = pt.models.VisGLVQ2D(data=train_ds) + es = pl.callbacks.EarlyStopping( + monitor="train_acc", + min_delta=0.001, + patience=20, + mode="max", + verbose=False, + check_on_train_epoch_end=True, + ) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[ + vis, + es, + ], + weights_summary="full", + accelerator="ddp", + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/examples/ksom_colors.py b/examples/ksom_colors.py index eee4a04..f86d3ac 100644 --- a/examples/ksom_colors.py +++ b/examples/ksom_colors.py @@ -10,6 +10,7 @@ from prototorch.utils.colors import hex_to_rgb class Vis2DColorSOM(pl.Callback): + def __init__(self, data, title="ColorSOMe", pause_time=0.1): super().__init__() self.title = title diff --git a/examples/lvqmln_iris.py b/examples/lvqmln_iris.py index 79df874..6a6023c 100644 --- a/examples/lvqmln_iris.py +++ b/examples/lvqmln_iris.py @@ -8,6 +8,7 @@ import torch class Backbone(torch.nn.Module): + def __init__(self, input_size=4, hidden_size=10, latent_size=2): super().__init__() self.input_size = input_size diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 9ca9d07..e7a297b 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -8,6 +8,7 @@ import torch class Backbone(torch.nn.Module): + def __init__(self, input_size=4, hidden_size=10, latent_size=2): super().__init__() self.input_size = input_size diff --git a/examples/siamese_gtlvq_iris.py b/examples/siamese_gtlvq_iris.py new file mode 100644 index 0000000..455c0fb --- /dev/null +++ b/examples/siamese_gtlvq_iris.py @@ -0,0 +1,73 @@ +"""Siamese GTLVQ example using all four dimensions of the Iris dataset.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch + + +class Backbone(torch.nn.Module): + + def __init__(self, input_size=4, hidden_size=10, latent_size=2): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.latent_size = latent_size + self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size) + self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size) + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + x = self.activation(self.dense1(x)) + out = self.activation(self.dense2(x)) + return out + + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = pt.datasets.Iris() + + # Reproducibility + pl.utilities.seed.seed_everything(seed=2) + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) + + # Hyperparameters + hparams = dict(distribution=[1, 2, 3], + proto_lr=0.01, + bb_lr=0.01, + input_dim=2, + latent_dim=1) + + # Initialize the backbone + backbone = Backbone(latent_size=hparams["input_dim"]) + + # Initialize the model + model = pt.models.SiameseGTLVQ( + hparams, + prototypes_initializer=pt.initializers.SMCI(train_ds), + backbone=backbone, + both_path_gradients=False, + ) + + # Model summary + print(model) + + # Callbacks + vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[vis], + ) + + # Training loop + trainer.fit(model, train_loader) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index b75d45e..999ac01 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -8,17 +8,34 @@ from .glvq import ( GLVQ21, GMLVQ, GRLVQ, + GTLVQ, LGMLVQ, LVQMLN, ImageGLVQ, ImageGMLVQ, + ImageGTLVQ, SiameseGLVQ, SiameseGMLVQ, + SiameseGTLVQ, ) from .knn import KNN -from .lvq import LVQ1, LVQ21, MedianLVQ -from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ -from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas +from .lvq import ( + LVQ1, + LVQ21, + MedianLVQ, +) +from .probabilistic import ( + CELVQ, + PLVQ, + RSLVQ, + SLVQ, +) +from .unsupervised import ( + GrowingNeuralGas, + HeskesSOM, + KohonenSOM, + NeuralGas, +) from .vis import * __version__ = "0.4.0" diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index c6c1164..4c2355e 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -14,6 +14,7 @@ from ..nn.wrappers import LambdaLayer class ProtoTorchBolt(pl.LightningModule): """All ProtoTorch models are ProtoTorch Bolts.""" + def __init__(self, hparams, **kwargs): super().__init__() @@ -52,6 +53,7 @@ class ProtoTorchBolt(pl.LightningModule): class PrototypeModel(ProtoTorchBolt): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -81,6 +83,7 @@ class PrototypeModel(ProtoTorchBolt): class UnsupervisedPrototypeModel(PrototypeModel): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -103,6 +106,7 @@ class UnsupervisedPrototypeModel(PrototypeModel): class SupervisedPrototypeModel(PrototypeModel): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -135,7 +139,7 @@ class SupervisedPrototypeModel(PrototypeModel): distances = self.compute_distances(x) _, plabels = self.proto_layer() winning = stratified_min_pooling(distances, plabels) - y_pred = torch.nn.functional.softmin(winning) + y_pred = torch.nn.functional.softmin(winning, dim=1) return y_pred def predict_from_distances(self, distances): @@ -178,6 +182,7 @@ class ProtoTorchMixin(object): class NonGradientMixin(ProtoTorchMixin): """Mixin for custom non-gradient optimization.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.automatic_optimization = False @@ -188,6 +193,7 @@ class NonGradientMixin(ProtoTorchMixin): class ImagePrototypesMixin(ProtoTorchMixin): """Mixin for models with image prototypes.""" + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): """Constrain the components to the range [0, 1] by clamping after updates.""" self.proto_layer.components.data.clamp_(0.0, 1.0) diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index 095f41d..f12d162 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -11,6 +11,7 @@ from .extras import ConnectionTopology class PruneLoserPrototypes(pl.Callback): + def __init__(self, threshold=0.01, idle_epochs=10, @@ -67,6 +68,7 @@ class PruneLoserPrototypes(pl.Callback): class PrototypeConvergence(pl.Callback): + def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False): self.min_delta = min_delta self.idle_epochs = idle_epochs # epochs to wait @@ -89,6 +91,7 @@ class GNGCallback(pl.Callback): Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke. """ + def __init__(self, reduction=0.1, freq=10): self.reduction = reduction self.freq = freq diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index a8cba61..8eeb554 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -13,6 +13,7 @@ from .glvq import SiameseGLVQ class CBC(SiameseGLVQ): """Classification-By-Components.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) diff --git a/prototorch/models/extras.py b/prototorch/models/extras.py index 644d0f1..8a230d3 100644 --- a/prototorch/models/extras.py +++ b/prototorch/models/extras.py @@ -15,7 +15,46 @@ def rank_scaled_gaussian(distances, lambd): return torch.exp(-torch.exp(-ranks / lambd) * distances) +def orthogonalization(tensors): + """Orthogonalization via polar decomposition """ + u, _, v = torch.svd(tensors, compute_uv=True) + u_shape = tuple(list(u.shape)) + v_shape = tuple(list(v.shape)) + + # reshape to (num x N x M) + u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1])) + v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1])) + + out = u @ v.permute([0, 2, 1]) + + out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], )) + + return out + + +def ltangent_distance(x, y, omegas): + r"""Localized Tangent distance. + Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T` + Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2` + + :param `torch.tensor` omegas: Three dimensional matrix + :rtype: `torch.tensor` + """ + x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( + omegas, omegas.permute([0, 2, 1])) + projected_x = x @ p + projected_y = torch.diagonal(y @ p).T + expanded_y = torch.unsqueeze(projected_y, dim=1) + batchwise_difference = expanded_y - projected_x + differences_squared = batchwise_difference**2 + distances = torch.sqrt(torch.sum(differences_squared, dim=2)) + distances = distances.permute(1, 0) + return distances + + class GaussianPrior(torch.nn.Module): + def __init__(self, variance): super().__init__() self.variance = variance @@ -25,6 +64,7 @@ class GaussianPrior(torch.nn.Module): class RankScaledGaussianPrior(torch.nn.Module): + def __init__(self, lambd): super().__init__() self.lambd = lambd @@ -34,6 +74,7 @@ class RankScaledGaussianPrior(torch.nn.Module): class ConnectionTopology(torch.nn.Module): + def __init__(self, agelimit, num_prototypes): super().__init__() self.agelimit = agelimit diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index c781784..66ad1d7 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -4,16 +4,26 @@ import torch from torch.nn.parameter import Parameter from ..core.competitions import wtac -from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance +from ..core.distances import ( + lomega_distance, + omega_distance, + squared_euclidean_distance, +) from ..core.initializers import EyeTransformInitializer -from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss +from ..core.losses import ( + GLVQLoss, + lvq1_loss, + lvq21_loss, +) from ..core.transforms import LinearTransform from ..nn.wrappers import LambdaLayer, LossLayer from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel +from .extras import ltangent_distance, orthogonalization class GLVQ(SupervisedPrototypeModel): """Generalized Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -98,6 +108,7 @@ class SiameseGLVQ(GLVQ): transformation pipeline are only learned from the inputs. """ + def __init__(self, hparams, backbone=torch.nn.Identity(), @@ -164,6 +175,7 @@ class LVQMLN(SiameseGLVQ): rather in the embedding space. """ + def compute_distances(self, x): latent_protos, _ = self.proto_layer() latent_x = self.backbone(x) @@ -179,6 +191,7 @@ class GRLVQ(SiameseGLVQ): TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. """ + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -204,6 +217,7 @@ class SiameseGMLVQ(SiameseGLVQ): Implemented as a Siamese network with a linear transformation backbone. """ + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -234,6 +248,7 @@ class GMLVQ(GLVQ): function. This makes it easier to implement a localized variant. """ + def __init__(self, hparams, **kwargs): distance_fn = kwargs.pop("distance_fn", omega_distance) super().__init__(hparams, distance_fn=distance_fn, **kwargs) @@ -268,6 +283,7 @@ class GMLVQ(GLVQ): class LGMLVQ(GMLVQ): """Localized and Generalized Matrix Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): distance_fn = kwargs.pop("distance_fn", lomega_distance) super().__init__(hparams, distance_fn=distance_fn, **kwargs) @@ -282,8 +298,48 @@ class LGMLVQ(GMLVQ): self.register_parameter("_omega", Parameter(omega)) +class GTLVQ(LGMLVQ): + """Localized and Generalized Tangent Learning Vector Quantization.""" + + def __init__(self, hparams, **kwargs): + distance_fn = kwargs.pop("distance_fn", ltangent_distance) + super().__init__(hparams, distance_fn=distance_fn, **kwargs) + + omega_initializer = kwargs.get("omega_initializer") + + if omega_initializer is not None: + subspace = omega_initializer.generate(self.hparams.input_dim, + self.hparams.latent_dim) + omega = torch.repeat_interleave(subspace.unsqueeze(0), + self.num_prototypes, + dim=0) + else: + omega = torch.rand( + self.num_prototypes, + self.hparams.input_dim, + self.hparams.latent_dim, + device=self.device, + ) + + # Re-register `_omega` to override the one from the super class. + self.register_parameter("_omega", Parameter(omega)) + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + with torch.no_grad(): + self._omega.copy_(orthogonalization(self._omega)) + + +class SiameseGTLVQ(SiameseGLVQ, GTLVQ): + """Generalized Tangent Learning Vector Quantization. + + Implemented as a Siamese network with a linear transformation backbone. + + """ + + class GLVQ1(GLVQ): """Generalized Learning Vector Quantization 1.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) self.loss = LossLayer(lvq1_loss) @@ -292,6 +348,7 @@ class GLVQ1(GLVQ): class GLVQ21(GLVQ): """Generalized Learning Vector Quantization 2.1.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) self.loss = LossLayer(lvq21_loss) @@ -314,3 +371,18 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ): after updates. """ + + +class ImageGTLVQ(ImagePrototypesMixin, GTLVQ): + """GTLVQ for training on image data. + + GTLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + + """ + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + """Constrain the components to the range [0, 1] by clamping after updates.""" + self.proto_layer.components.data.clamp_(0.0, 1.0) + with torch.no_grad(): + self._omega.copy_(orthogonalization(self._omega)) diff --git a/prototorch/models/knn.py b/prototorch/models/knn.py index 0886550..f1a7be5 100644 --- a/prototorch/models/knn.py +++ b/prototorch/models/knn.py @@ -4,13 +4,17 @@ import warnings from ..core.competitions import KNNC from ..core.components import LabeledComponents -from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer +from ..core.initializers import ( + LiteralCompInitializer, + LiteralLabelsInitializer, +) from ..utils.utils import parse_data_arg from .abstract import SupervisedPrototypeModel class KNN(SupervisedPrototypeModel): """K-Nearest-Neighbors classification algorithm.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py index f398f93..655f539 100644 --- a/prototorch/models/lvq.py +++ b/prototorch/models/lvq.py @@ -9,6 +9,7 @@ from .glvq import GLVQ class LVQ1(NonGradientMixin, GLVQ): """Learning Vector Quantization 1.""" + def training_step(self, train_batch, batch_idx, optimizer_idx=None): protos, plables = self.proto_layer() x, y = train_batch @@ -38,6 +39,7 @@ class LVQ1(NonGradientMixin, GLVQ): class LVQ21(NonGradientMixin, GLVQ): """Learning Vector Quantization 2.1.""" + def training_step(self, train_batch, batch_idx, optimizer_idx=None): protos, plabels = self.proto_layer() @@ -70,6 +72,7 @@ class MedianLVQ(NonGradientMixin, GLVQ): # TODO Avoid computing distances over and over """ + def __init__(self, hparams, verbose=True, **kwargs): self.verbose = verbose super().__init__(hparams, **kwargs) diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py index c00375f..cb9948a 100644 --- a/prototorch/models/probabilistic.py +++ b/prototorch/models/probabilistic.py @@ -11,6 +11,7 @@ from .glvq import GLVQ, SiameseGMLVQ class CELVQ(GLVQ): """Cross-Entropy Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -29,6 +30,7 @@ class CELVQ(GLVQ): class ProbabilisticLVQ(GLVQ): + def __init__(self, hparams, rejection_confidence=0.0, **kwargs): super().__init__(hparams, **kwargs) @@ -62,6 +64,7 @@ class ProbabilisticLVQ(GLVQ): class SLVQ(ProbabilisticLVQ): """Soft Learning Vector Quantization.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.loss = LossLayer(nllr_loss) @@ -70,6 +73,7 @@ class SLVQ(ProbabilisticLVQ): class RSLVQ(ProbabilisticLVQ): """Robust Soft Learning Vector Quantization.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.loss = LossLayer(rslvq_loss) @@ -81,6 +85,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): TODO: Use Backbone LVQ instead """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.conditional_distribution = RankScaledGaussianPrior( diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index c18f033..ed2a796 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -18,6 +18,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): TODO Allow non-2D grids """ + def __init__(self, hparams, **kwargs): h, w = hparams.get("shape") # Ignore `num_prototypes` @@ -69,6 +70,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): class HeskesSOM(UnsupervisedPrototypeModel): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -78,6 +80,7 @@ class HeskesSOM(UnsupervisedPrototypeModel): class NeuralGas(UnsupervisedPrototypeModel): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) @@ -110,6 +113,7 @@ class NeuralGas(UnsupervisedPrototypeModel): class GrowingNeuralGas(NeuralGas): + def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 9744a9d..49724a0 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -11,6 +11,7 @@ from ..utils.utils import mesh2d class Vis2DAbstract(pl.Callback): + def __init__(self, data, title="Prototype Visualization", @@ -118,6 +119,7 @@ class Vis2DAbstract(pl.Callback): class VisGLVQ2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): if not self.precheck(trainer): return True @@ -141,6 +143,7 @@ class VisGLVQ2D(Vis2DAbstract): class VisSiameseGLVQ2D(Vis2DAbstract): + def __init__(self, *args, map_protos=True, **kwargs): super().__init__(*args, **kwargs) self.map_protos = map_protos @@ -179,6 +182,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract): class VisGMLVQ2D(Vis2DAbstract): + def __init__(self, *args, ev_proj=True, **kwargs): super().__init__(*args, **kwargs) self.ev_proj = ev_proj @@ -212,6 +216,7 @@ class VisGMLVQ2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): if not self.precheck(trainer): return True @@ -235,6 +240,7 @@ class VisCBC2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract): + def on_epoch_end(self, trainer, pl_module): if not self.precheck(trainer): return True @@ -262,6 +268,7 @@ class VisNG2D(Vis2DAbstract): class VisImgComp(Vis2DAbstract): + def __init__(self, *args, random_data=0, diff --git a/setup.cfg b/setup.cfg index 24eeb0b..e3c8135 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,23 @@ -[isort] -profile = hug -src_paths = isort, test - [yapf] based_on_style = pep8 spaces_before_comment = 2 split_before_logical_operator = true + +[pylint] +disable = + too-many-arguments, + too-few-public-methods, + fixme, + + +[pycodestyle] +max-line-length = 79 + +[isort] +profile = hug +src_paths = isort, test +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 3 +use_parentheses = True +line_length = 79 diff --git a/tests/test_.py b/tests/test_.py index 88da1bc..d3c12d8 100644 --- a/tests/test_.py +++ b/tests/test_.py @@ -4,6 +4,7 @@ import unittest class TestDummy(unittest.TestCase): + def setUp(self): pass