test(githooks): Add githooks for automatic commit checks

This commit is contained in:
Alexander Engelsberger 2021-06-16 16:16:34 +02:00
parent c87ed5ba8b
commit 8956ee75ad
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
24 changed files with 196 additions and 49 deletions

54
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,54 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.902'
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.31.0' # Use the sha / tag you want to point at
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0 # Use the ref you want to point at
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
hooks:
- id: pyupgrade
- repo: https://github.com/jorisroovers/gitlint
rev: "v0.15.1"
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]

View File

@ -18,6 +18,19 @@ pip install prototorch_models
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
be available for use in your Python environment as `prototorch.models`.
## Contribution
This repository contains definition for [git hooks](https://githooks.com).
[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
Please install the hooks by running
```bash
pre-commit install
pre-commit install --hook-type commit-msg
```
before creating the first commit.
## Available models
### LVQ Family

View File

@ -70,4 +70,3 @@
year="2018",
publisher="Springer International Publishing",
}

View File

@ -1,12 +1,20 @@
"""GMLVQ example using the MNIST dataset."""
from prototorch.models import ImageGMLVQ
from prototorch.models.data import train_on_mnist
import torch
from pytorch_lightning.utilities.cli import LightningCLI
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
"""Model Definition."""
import prototorch as pt
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
cli = LightningCLI(GMLVQMNIST)
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@ -1,9 +1,11 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
input_dim: 784
latent_dim: 784
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
import prototorch as pt
def hex_to_rgb(hex_values):
for v in hex_values:

View File

@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -3,10 +3,11 @@
import argparse
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
def plot_matrix(matrix):
title = "Lambda matrix"

View File

@ -2,13 +2,14 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision.transforms import Lambda
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@ -1,10 +1,19 @@
import prototorch as pt
"""Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import prototorch as pt
# MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
return mnist_test
def train_on_mnist(batch_size=256) -> type:
class DataClass(pl.LightningModule):
datamodule = MNISTDataModule(batch_size=batch_size)
# def train_on_mnist(batch_size=256) -> type:
# class DataClass(pl.LightningModule):
# datamodule = MNISTDataModule(batch_size=batch_size)
def __init__(self, *args, **kwargs):
prototype_initializer = kwargs.pop(
"prototype_initializer", pt.components.Zeros((28, 28, 1)))
super().__init__(*args,
prototype_initializer=prototype_initializer,
**kwargs)
# def __init__(self, *args, **kwargs):
# prototype_initializer = kwargs.pop(
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
return DataClass
# dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)

View File

@ -3,11 +3,8 @@
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.distances import (lomega_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer, LossLayer

View File

@ -1,10 +1,12 @@
"""
_____ _ _______ _
| __ \ | | |__ __| | |
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|Plugin
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
@ -29,6 +31,7 @@ CLI = [
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",