test(githooks): Add githooks for automatic commit checks

This commit is contained in:
Alexander Engelsberger 2021-06-16 16:16:34 +02:00
parent c87ed5ba8b
commit 8956ee75ad
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
24 changed files with 196 additions and 49 deletions

54
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,54 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.902'
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.31.0' # Use the sha / tag you want to point at
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0 # Use the ref you want to point at
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
hooks:
- id: pyupgrade
- repo: https://github.com/jorisroovers/gitlint
rev: "v0.15.1"
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]

View File

@ -18,6 +18,19 @@ pip install prototorch_models
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
be available for use in your Python environment as `prototorch.models`. be available for use in your Python environment as `prototorch.models`.
## Contribution
This repository contains definition for [git hooks](https://githooks.com).
[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
Please install the hooks by running
```bash
pre-commit install
pre-commit install --hook-type commit-msg
```
before creating the first commit.
## Available models ## Available models
### LVQ Family ### LVQ Family

View File

@ -70,4 +70,3 @@
year="2018", year="2018",
publisher="Springer International Publishing", publisher="Springer International Publishing",
} }

View File

@ -1,12 +1,20 @@
"""GMLVQ example using the MNIST dataset.""" """GMLVQ example using the MNIST dataset."""
from prototorch.models import ImageGMLVQ import torch
from prototorch.models.data import train_on_mnist
from pytorch_lightning.utilities.cli import LightningCLI from pytorch_lightning.utilities.cli import LightningCLI
import prototorch as pt
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ): from prototorch.models import ImageGMLVQ
"""Model Definition.""" from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
cli = LightningCLI(GMLVQMNIST) class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@ -1,9 +1,11 @@
model: model:
hparams: hparams:
input_dim: 784
latent_dim: 784
distribution: distribution:
num_classes: 10 num_classes: 10
prototypes_per_class: 2 prototypes_per_class: 2
input_dim: 784
latent_dim: 784
proto_lr: 0.01 proto_lr: 0.01
bb_lr: 0.01 bb_lr: 0.01
data:
batch_size: 32

View File

@ -2,10 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,10 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,10 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import prototorch as pt
def hex_to_rgb(hex_values): def hex_to_rgb(hex_values):
for v in hex_values: for v in hex_values:

View File

@ -2,10 +2,11 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -3,10 +3,11 @@
import argparse import argparse
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
def plot_matrix(matrix): def plot_matrix(matrix):
title = "Lambda matrix" title = "Lambda matrix"

View File

@ -2,13 +2,14 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchvision.transforms import Lambda from torchvision.transforms import Lambda
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -1,10 +1,19 @@
import prototorch as pt """Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
import prototorch as pt
# MNIST
class MNISTDataModule(pl.LightningDataModule): class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32): def __init__(self, batch_size=32):
super().__init__() super().__init__()
@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
return mnist_test return mnist_test
def train_on_mnist(batch_size=256) -> type: # def train_on_mnist(batch_size=256) -> type:
class DataClass(pl.LightningModule): # class DataClass(pl.LightningModule):
datamodule = MNISTDataModule(batch_size=batch_size) # datamodule = MNISTDataModule(batch_size=batch_size)
def __init__(self, *args, **kwargs): # def __init__(self, *args, **kwargs):
prototype_initializer = kwargs.pop( # prototype_initializer = kwargs.pop(
"prototype_initializer", pt.components.Zeros((28, 28, 1))) # "prototype_initializer", pt.components.Zeros((28, 28, 1)))
super().__init__(*args, # super().__init__(*args,
prototype_initializer=prototype_initializer, # prototype_initializer=prototype_initializer,
**kwargs) # **kwargs)
return DataClass # dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)

View File

@ -3,11 +3,8 @@
import torch import torch
from prototorch.functions.activations import get_activation from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import ( from prototorch.functions.distances import (lomega_distance, omega_distance,
lomega_distance, squared_euclidean_distance)
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer, LossLayer from prototorch.modules import LambdaLayer, LossLayer

View File

@ -1,10 +1,12 @@
""" """
_____ _ _______ _
| __ \ | | |__ __| | | ######
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__ # # ##### #### ##### #### ##### #### ##### #### # #
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \ # # # # # # # # # # # # # # # # # #
| | | | | (_) | || (_) | | (_) | | | (__| | | | ###### # # # # # # # # # # # # # ######
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|Plugin # ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package ProtoTorch models Plugin Package
""" """
@ -29,6 +31,7 @@ CLI = [
] ]
DEV = [ DEV = [
"bumpversion", "bumpversion",
"pre-commit",
] ]
DOCS = [ DOCS = [
"recommonmark", "recommonmark",