test(githooks): Add githooks for automatic commit checks
This commit is contained in:
		
							
								
								
									
										54
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
			
		||||
# See https://pre-commit.com for more information
 | 
			
		||||
# See https://pre-commit.com/hooks.html for more hooks
 | 
			
		||||
repos:
 | 
			
		||||
-   repo: https://github.com/pre-commit/pre-commit-hooks
 | 
			
		||||
    rev: v4.0.1
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: trailing-whitespace
 | 
			
		||||
    -   id: end-of-file-fixer
 | 
			
		||||
    -   id: check-yaml
 | 
			
		||||
    -   id: check-added-large-files
 | 
			
		||||
    -   id: check-ast
 | 
			
		||||
    -   id: check-case-conflict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
- repo: https://github.com/myint/autoflake
 | 
			
		||||
  rev: v1.4
 | 
			
		||||
  hooks:
 | 
			
		||||
  -   id: autoflake
 | 
			
		||||
 | 
			
		||||
- repo: http://github.com/PyCQA/isort
 | 
			
		||||
  rev: 5.8.0
 | 
			
		||||
  hooks:
 | 
			
		||||
  -   id: isort
 | 
			
		||||
 | 
			
		||||
-   repo: https://github.com/pre-commit/mirrors-mypy
 | 
			
		||||
    rev: 'v0.902'
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: mypy
 | 
			
		||||
        files: prototorch
 | 
			
		||||
        additional_dependencies: [types-pkg_resources]
 | 
			
		||||
 | 
			
		||||
-   repo: https://github.com/pre-commit/mirrors-yapf
 | 
			
		||||
    rev: 'v0.31.0'  # Use the sha / tag you want to point at
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: yapf
 | 
			
		||||
 | 
			
		||||
-   repo: https://github.com/pre-commit/pygrep-hooks
 | 
			
		||||
    rev: v1.9.0  # Use the ref you want to point at
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: python-use-type-annotations
 | 
			
		||||
    -   id: python-no-log-warn
 | 
			
		||||
    -   id: python-check-blanket-noqa
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
-   repo: https://github.com/asottile/pyupgrade
 | 
			
		||||
    rev: v2.19.4
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: pyupgrade
 | 
			
		||||
 | 
			
		||||
-   repo: https://github.com/jorisroovers/gitlint
 | 
			
		||||
    rev: "v0.15.1"
 | 
			
		||||
    hooks:
 | 
			
		||||
    -   id: gitlint
 | 
			
		||||
        args: [--contrib=CT1, --ignore=B6, --msg-filename]
 | 
			
		||||
							
								
								
									
										13
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.md
									
									
									
									
									
								
							@@ -18,6 +18,19 @@ pip install prototorch_models
 | 
			
		||||
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 | 
			
		||||
be available for use in your Python environment as `prototorch.models`.
 | 
			
		||||
 | 
			
		||||
## Contribution
 | 
			
		||||
 | 
			
		||||
This repository contains definition for [git hooks](https://githooks.com).
 | 
			
		||||
[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
 | 
			
		||||
Please install the hooks by running
 | 
			
		||||
```bash
 | 
			
		||||
pre-commit install
 | 
			
		||||
pre-commit install --hook-type commit-msg
 | 
			
		||||
```
 | 
			
		||||
before creating the first commit.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Available models
 | 
			
		||||
 | 
			
		||||
### LVQ Family
 | 
			
		||||
 
 | 
			
		||||
@@ -70,4 +70,3 @@
 | 
			
		||||
    year="2018",
 | 
			
		||||
    publisher="Springer International Publishing",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,12 +1,20 @@
 | 
			
		||||
"""GMLVQ example using the MNIST dataset."""
 | 
			
		||||
 | 
			
		||||
from prototorch.models import ImageGMLVQ
 | 
			
		||||
from prototorch.models.data import train_on_mnist
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch_lightning.utilities.cli import LightningCLI
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
 | 
			
		||||
    """Model Definition."""
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
from prototorch.models import ImageGMLVQ
 | 
			
		||||
from prototorch.models.abstract import PrototypeModel
 | 
			
		||||
from prototorch.models.data import MNISTDataModule
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
cli = LightningCLI(GMLVQMNIST)
 | 
			
		||||
class ExperimentClass(ImageGMLVQ):
 | 
			
		||||
    def __init__(self, hparams, **kwargs):
 | 
			
		||||
        super().__init__(hparams,
 | 
			
		||||
                         optimizer=torch.optim.Adam,
 | 
			
		||||
                         prototype_initializer=pt.components.zeros(28 * 28),
 | 
			
		||||
                         **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,11 @@
 | 
			
		||||
model:
 | 
			
		||||
  hparams:
 | 
			
		||||
    input_dim: 784
 | 
			
		||||
    latent_dim: 784
 | 
			
		||||
    distribution:
 | 
			
		||||
      num_classes: 10
 | 
			
		||||
      prototypes_per_class: 2
 | 
			
		||||
    input_dim: 784
 | 
			
		||||
    latent_dim: 784
 | 
			
		||||
    proto_lr: 0.01
 | 
			
		||||
    bb_lr: 0.01
 | 
			
		||||
data:
 | 
			
		||||
  batch_size: 32
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,11 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,11 +2,12 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,11 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,11 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,11 +2,12 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from sklearn.datasets import load_iris
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,11 +2,12 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from matplotlib import pyplot as plt
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def hex_to_rgb(hex_values):
 | 
			
		||||
    for v in hex_values:
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,11 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -3,10 +3,11 @@
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_matrix(matrix):
 | 
			
		||||
    title = "Lambda matrix"
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,14 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from sklearn.datasets import load_iris
 | 
			
		||||
from sklearn.preprocessing import StandardScaler
 | 
			
		||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,11 +2,12 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from torchvision.transforms import Lambda
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,19 @@
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
"""Prototorch Data Modules
 | 
			
		||||
 | 
			
		||||
This allows to store the used dataset inside a Lightning Module.
 | 
			
		||||
Mainly used for PytorchLightningCLI configurations.
 | 
			
		||||
"""
 | 
			
		||||
from typing import Any, Optional, Type
 | 
			
		||||
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
from torch.utils.data import DataLoader, random_split
 | 
			
		||||
from torch.utils.data import DataLoader, Dataset, random_split
 | 
			
		||||
from torchvision import transforms
 | 
			
		||||
from torchvision.datasets import MNIST
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# MNIST
 | 
			
		||||
class MNISTDataModule(pl.LightningDataModule):
 | 
			
		||||
    def __init__(self, batch_size=32):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
 | 
			
		||||
        return mnist_test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train_on_mnist(batch_size=256) -> type:
 | 
			
		||||
    class DataClass(pl.LightningModule):
 | 
			
		||||
        datamodule = MNISTDataModule(batch_size=batch_size)
 | 
			
		||||
# def train_on_mnist(batch_size=256) -> type:
 | 
			
		||||
#     class DataClass(pl.LightningModule):
 | 
			
		||||
#         datamodule = MNISTDataModule(batch_size=batch_size)
 | 
			
		||||
 | 
			
		||||
        def __init__(self, *args, **kwargs):
 | 
			
		||||
            prototype_initializer = kwargs.pop(
 | 
			
		||||
                "prototype_initializer", pt.components.Zeros((28, 28, 1)))
 | 
			
		||||
            super().__init__(*args,
 | 
			
		||||
                             prototype_initializer=prototype_initializer,
 | 
			
		||||
                             **kwargs)
 | 
			
		||||
#         def __init__(self, *args, **kwargs):
 | 
			
		||||
#             prototype_initializer = kwargs.pop(
 | 
			
		||||
#                 "prototype_initializer", pt.components.Zeros((28, 28, 1)))
 | 
			
		||||
#             super().__init__(*args,
 | 
			
		||||
#                              prototype_initializer=prototype_initializer,
 | 
			
		||||
#                              **kwargs)
 | 
			
		||||
 | 
			
		||||
    return DataClass
 | 
			
		||||
#     dc: Type[DataClass] = DataClass
 | 
			
		||||
#     return dc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ABSTRACT
 | 
			
		||||
class GeneralDataModule(pl.LightningDataModule):
 | 
			
		||||
    def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.train_dataset = dataset
 | 
			
		||||
        self.batch_size = batch_size
 | 
			
		||||
 | 
			
		||||
    def train_dataloader(self) -> DataLoader:
 | 
			
		||||
        return DataLoader(self.train_dataset, batch_size=self.batch_size)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
 | 
			
		||||
#     class DataClass(pl.LightningModule):
 | 
			
		||||
#         datamodule = GeneralDataModule(dataset, batch_size)
 | 
			
		||||
#         datashape = dataset[0][0].shape
 | 
			
		||||
#         example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
#         def __init__(self, *args: Any, **kwargs: Any) -> None:
 | 
			
		||||
#             prototype_initializer = kwargs.pop(
 | 
			
		||||
#                 "prototype_initializer",
 | 
			
		||||
#                 pt.components.Zeros(self.datashape),
 | 
			
		||||
#             )
 | 
			
		||||
#             super().__init__(*args,
 | 
			
		||||
#                              prototype_initializer=prototype_initializer,
 | 
			
		||||
#                              **kwargs)
 | 
			
		||||
 | 
			
		||||
#     return DataClass
 | 
			
		||||
 | 
			
		||||
# if __name__ == "__main__":
 | 
			
		||||
#     from prototorch.models import GLVQ
 | 
			
		||||
 | 
			
		||||
#     demo_dataset = pt.datasets.Iris()
 | 
			
		||||
 | 
			
		||||
#     TrainingClass: Type = train_on_dataset(demo_dataset)
 | 
			
		||||
 | 
			
		||||
#     class DemoGLVQ(TrainingClass, GLVQ):
 | 
			
		||||
#         """Model Definition."""
 | 
			
		||||
 | 
			
		||||
#     # Hyperparameters
 | 
			
		||||
#     hparams = dict(
 | 
			
		||||
#         distribution={
 | 
			
		||||
#             "num_classes": 3,
 | 
			
		||||
#             "prototypes_per_class": 4
 | 
			
		||||
#         },
 | 
			
		||||
#         lr=0.01,
 | 
			
		||||
#     )
 | 
			
		||||
 | 
			
		||||
#     initialized = DemoGLVQ(hparams)
 | 
			
		||||
#     print(initialized)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,11 +3,8 @@
 | 
			
		||||
import torch
 | 
			
		||||
from prototorch.functions.activations import get_activation
 | 
			
		||||
from prototorch.functions.competitions import wtac
 | 
			
		||||
from prototorch.functions.distances import (
 | 
			
		||||
    lomega_distance,
 | 
			
		||||
    omega_distance,
 | 
			
		||||
    squared_euclidean_distance,
 | 
			
		||||
)
 | 
			
		||||
from prototorch.functions.distances import (lomega_distance, omega_distance,
 | 
			
		||||
                                            squared_euclidean_distance)
 | 
			
		||||
from prototorch.functions.helper import get_flat
 | 
			
		||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
			
		||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										15
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								setup.py
									
									
									
									
									
								
							@@ -1,10 +1,12 @@
 | 
			
		||||
"""
 | 
			
		||||
  _____           _     _______             _
 | 
			
		||||
 |  __ \         | |   |__   __|           | |
 | 
			
		||||
 | |__) | __ ___ | |_ ___ | | ___  _ __ ___| |__
 | 
			
		||||
 |  ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
 | 
			
		||||
 | |   | | | (_) | || (_) | | (_) | | | (__| | | |
 | 
			
		||||
 |_|   |_|  \___/ \__\___/|_|\___/|_|  \___|_| |_|Plugin
 | 
			
		||||
 | 
			
		||||
 ######
 | 
			
		||||
 #     # #####   ####  #####  ####  #####  ####  #####   ####  #    #
 | 
			
		||||
 #     # #    # #    #   #   #    #   #   #    # #    # #    # #    #
 | 
			
		||||
 ######  #    # #    #   #   #    #   #   #    # #    # #      ######
 | 
			
		||||
 #       #####  #    #   #   #    #   #   #    # #####  #      #    #
 | 
			
		||||
 #       #   #  #    #   #   #    #   #   #    # #   #  #    # #    #
 | 
			
		||||
 #       #    #  ####    #    ####    #    ####  #    #  ####  #    #Plugin
 | 
			
		||||
 | 
			
		||||
ProtoTorch models Plugin Package
 | 
			
		||||
"""
 | 
			
		||||
@@ -29,6 +31,7 @@ CLI = [
 | 
			
		||||
]
 | 
			
		||||
DEV = [
 | 
			
		||||
    "bumpversion",
 | 
			
		||||
    "pre-commit",
 | 
			
		||||
]
 | 
			
		||||
DOCS = [
 | 
			
		||||
    "recommonmark",
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user