test(githooks): Add githooks for automatic commit checks

Alexander Engelsberger 2021-06-16 16:16:34 +02:00
parent c87ed5ba8b
commit 8956ee75ad
24 changed files with 196 additions and 49 deletions


@@ -3,7 +3,7 @@ current_version = 0.1.7
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize =
    {major}.{minor}.{patch}

[bumpversion:file:setup.py]

.gitignore

@@ -138,4 +138,4 @@ lightning_logs/
# Pytorch Models or Weights
# If necessary make exceptions for single pretrained models
*.pt

.pre-commit-config.yaml (new file)

@@ -0,0 +1,54 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-case-conflict

  - repo: https://github.com/myint/autoflake
    rev: v1.4
    hooks:
      - id: autoflake

  - repo: http://github.com/PyCQA/isort
    rev: 5.8.0
    hooks:
      - id: isort

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v0.902'
    hooks:
      - id: mypy
        files: prototorch
        additional_dependencies: [types-pkg_resources]

  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: 'v0.31.0'  # Use the sha / tag you want to point at
    hooks:
      - id: yapf

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.9.0  # Use the ref you want to point at
    hooks:
      - id: python-use-type-annotations
      - id: python-no-log-warn
      - id: python-check-blanket-noqa

  - repo: https://github.com/asottile/pyupgrade
    rev: v2.19.4
    hooks:
      - id: pyupgrade

  - repo: https://github.com/jorisroovers/gitlint
    rev: "v0.15.1"
    hooks:
      - id: gitlint
        args: [--contrib=CT1, --ignore=B6, --msg-filename]
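A note on the gitlint entry above: `--contrib=CT1` enables gitlint's contrib rule for conventional-commit titles, `--ignore=B6` disables the body-is-missing rule (so commits without a body pass), and `--msg-filename` is how pre-commit's commit-msg stage hands the message file to gitlint. As a rough illustration, a simplified regex sketch of the title shape CT1 accepts (the authoritative check lives in gitlint itself), applied to this very commit's title:

```python
# Simplified sketch of the title format enforced by gitlint's CT1
# (conventional commits) rule; not gitlint's actual implementation.
import re

CONVENTIONAL_TITLE = re.compile(
    r"^(feat|fix|chore|docs|style|refactor|perf|test|revert|ci|build)"  # type
    r"(\([^)]+\))?"  # optional scope, e.g. "(githooks)"
    r": .+")  # separator and description

assert CONVENTIONAL_TITLE.match(
    "test(githooks): Add githooks for automatic commit checks")
```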


@@ -18,6 +18,19 @@ pip install prototorch_models
 of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 be available for use in your Python environment as `prototorch.models`.

+## Contribution
+
+This repository contains definitions for [git hooks](https://githooks.com).
+[Pre-commit](https://pre-commit.com) is installed as a development dependency with prototorch.
+Please install the hooks by running
+
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+
+before creating the first commit.
+
 ## Available models

 ### LVQ Family
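Following up on the Contribution section above: once installed, the hooks run automatically on `git commit`. A minimal sketch (plain `subprocess`, nothing prototorch-specific) of triggering the same checks over the whole tree, equivalent to running `pre-commit run --all-files` in a shell:

```python
# Run all configured pre-commit hooks against every file in the repository,
# not just the staged ones; raises CalledProcessError if any hook fails.
import subprocess

subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```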


@@ -4,4 +4,4 @@ Abstract Models
========================================
.. automodule:: prototorch.models.abstract
   :members:
   :undoc-members:


@@ -37,4 +37,4 @@ These models have been published in the past and have been adapted to the Prototorch
Customizable
-----------------------------------------
Prototorch Models also contains the building blocks to build your own models with PyTorch-Lightning and Prototorch.


@@ -71,7 +71,7 @@ Probabilistic Models
Probabilistic variants assume that the prototypes generate a probability distribution over the classes.
For a test sample they return a distribution instead of a class assignment.

The following two algorithms were presented by :cite:t:`seo2003`.
Every prototype is the center of a Gaussian distribution for its class, generating a mixture model.

.. autoclass:: prototorch.models.probabilistic.SLVQ
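As a sketch of the soft LVQ formulation from :cite:t:`seo2003` (assuming, for simplicity, a single shared bandwidth :math:`\sigma` for all prototypes :math:`w_j` with class labels :math:`c_j`), the class posterior of this mixture reads

.. math::

   p(c \mid x) = \frac{\sum_{j : c_j = c} \exp\left(-\lVert x - w_j \rVert^2 / 2\sigma^2\right)}
                      {\sum_{j} \exp\left(-\lVert x - w_j \rVert^2 / 2\sigma^2\right)}

The distribution returned for a test sample is exactly this posterior.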
@@ -114,4 +114,4 @@ The visualizations can be shown in a separate window and inside a tensorboard.
Bibliography
========================================
.. bibliography::


@@ -70,4 +70,3 @@
    year="2018",
    publisher="Springer International Publishing",
}


@@ -5,4 +5,4 @@ Examples in this folder use the experimental [Lightning CLI](https://pytorch-lig
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```


@@ -1,12 +1,20 @@
 """GMLVQ example using the MNIST dataset."""

-from prototorch.models import ImageGMLVQ
-from prototorch.models.data import train_on_mnist
+import torch
 from pytorch_lightning.utilities.cli import LightningCLI

+import prototorch as pt
+from prototorch.models import ImageGMLVQ
+from prototorch.models.abstract import PrototypeModel
+from prototorch.models.data import MNISTDataModule

-class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
-    """Model Definition."""
+
+class ExperimentClass(ImageGMLVQ):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams,
+                         optimizer=torch.optim.Adam,
+                         prototype_initializer=pt.components.zeros(28 * 28),
+                         **kwargs)

-cli = LightningCLI(GMLVQMNIST)
+cli = LightningCLI(ImageGMLVQ, MNISTDataModule)


@@ -1,9 +1,11 @@
 model:
   hparams:
-    input_dim: 784
-    latent_dim: 784
     distribution:
       num_classes: 10
       prototypes_per_class: 2
+    input_dim: 784
+    latent_dim: 784
     proto_lr: 0.01
     bb_lr: 0.01
+data:
+  batch_size: 32
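For readers not using the CLI: the `model.hparams` block above is what `LightningCLI` passes to the model as its first argument, and `data.batch_size` goes to `MNISTDataModule`. A minimal sketch of the equivalent direct construction (the `prototype_initializer` choice is borrowed from the example script above, not mandated by the config):

```python
# Sketch: build the model directly with the hparams from gmlvq.yaml,
# mirroring what LightningCLI does when given this config file.
import prototorch as pt
from prototorch.models import ImageGMLVQ
from prototorch.models.data import MNISTDataModule

hparams = dict(
    distribution=dict(num_classes=10, prototypes_per_class=2),
    input_dim=784,
    latent_dim=784,
    proto_lr=0.01,
    bb_lr=0.01,
)
model = ImageGMLVQ(hparams,
                   prototype_initializer=pt.components.zeros(28 * 28))
datamodule = MNISTDataModule(batch_size=32)
```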


@@ -2,10 +2,11 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,11 +2,12 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from torch.optim.lr_scheduler import ExponentialLR

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,10 +2,11 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,10 +2,11 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,11 +2,12 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,11 +2,12 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt

+import prototorch as pt

 def hex_to_rgb(hex_values):
     for v in hex_values:


@@ -2,10 +2,11 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -3,10 +3,11 @@
 import argparse

 import matplotlib.pyplot as plt
-import prototorch as pt
 import pytorch_lightning as pl
 import torch

+import prototorch as pt

 def plot_matrix(matrix):
     title = "Lambda matrix"


@@ -2,13 +2,14 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
 from torch.optim.lr_scheduler import ExponentialLR

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -2,11 +2,12 @@
 import argparse

-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from torchvision.transforms import Lambda

+import prototorch as pt

 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()


@@ -1,10 +1,19 @@
-import prototorch as pt
+"""Prototorch Data Modules
+
+This allows to store the used dataset inside a Lightning Module.
+Mainly used for PytorchLightningCLI configurations.
+"""
+from typing import Any, Optional, Type
+
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms
 from torchvision.datasets import MNIST

+import prototorch as pt

+# MNIST
 class MNISTDataModule(pl.LightningDataModule):
     def __init__(self, batch_size=32):
         super().__init__()
@@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
         return mnist_test

-def train_on_mnist(batch_size=256) -> type:
-    class DataClass(pl.LightningModule):
-        datamodule = MNISTDataModule(batch_size=batch_size)
-
-        def __init__(self, *args, **kwargs):
-            prototype_initializer = kwargs.pop(
-                "prototype_initializer", pt.components.Zeros((28, 28, 1)))
-            super().__init__(*args,
-                             prototype_initializer=prototype_initializer,
-                             **kwargs)
-
-    return DataClass
+# def train_on_mnist(batch_size=256) -> type:
+#     class DataClass(pl.LightningModule):
+#         datamodule = MNISTDataModule(batch_size=batch_size)
+
+#         def __init__(self, *args, **kwargs):
+#             prototype_initializer = kwargs.pop(
+#                 "prototype_initializer", pt.components.Zeros((28, 28, 1)))
+#             super().__init__(*args,
+#                              prototype_initializer=prototype_initializer,
+#                              **kwargs)
+
+#     dc: Type[DataClass] = DataClass
+#     return dc

+# ABSTRACT
+class GeneralDataModule(pl.LightningDataModule):
+    def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
+        super().__init__()
+        self.train_dataset = dataset
+        self.batch_size = batch_size
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_dataset, batch_size=self.batch_size)

+# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
+#     class DataClass(pl.LightningModule):
+#         datamodule = GeneralDataModule(dataset, batch_size)
+#         datashape = dataset[0][0].shape
+#         example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
+
+#         def __init__(self, *args: Any, **kwargs: Any) -> None:
+#             prototype_initializer = kwargs.pop(
+#                 "prototype_initializer",
+#                 pt.components.Zeros(self.datashape),
+#             )
+#             super().__init__(*args,
+#                              prototype_initializer=prototype_initializer,
+#                              **kwargs)
+
+#     return DataClass

+# if __name__ == "__main__":
+#     from prototorch.models import GLVQ
+
+#     demo_dataset = pt.datasets.Iris()
+
+#     TrainingClass: Type = train_on_dataset(demo_dataset)
+
+#     class DemoGLVQ(TrainingClass, GLVQ):
+#         """Model Definition."""
+
+#     # Hyperparameters
+#     hparams = dict(
+#         distribution={
+#             "num_classes": 3,
+#             "prototypes_per_class": 4
+#         },
+#         lr=0.01,
+#     )
+
+#     initialized = DemoGLVQ(hparams)
+#     print(initialized)
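A short usage sketch for the new `GeneralDataModule`, reusing `pt.datasets.Iris()` from the commented-out demo above (assuming the dataset yields `(input, target)` pairs, as torch `Dataset`s conventionally do):

```python
# Wrap an arbitrary torch Dataset in GeneralDataModule and draw one batch
# from its train_dataloader, just as a pl.Trainer would during fitting.
import prototorch as pt
from prototorch.models.data import GeneralDataModule

dm = GeneralDataModule(pt.datasets.Iris(), batch_size=64)
x_batch, y_batch = next(iter(dm.train_dataloader()))
print(x_batch.shape, y_batch.shape)
```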


@@ -3,11 +3,8 @@
 import torch
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
-from prototorch.functions.distances import (
-    lomega_distance,
-    omega_distance,
-    squared_euclidean_distance,
-)
+from prototorch.functions.distances import (lomega_distance, omega_distance,
+                                            squared_euclidean_distance)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer, LossLayer


@@ -1,10 +1,12 @@
 """
-[ASCII-art "Prototorch" banner drawn with underscores and pipes, ending in "Plugin"]
+[ASCII-art "Prototorch" banner drawn with '#' characters, ending in "Plugin"]

 ProtoTorch models Plugin Package
 """
@@ -29,6 +31,7 @@ CLI = [
 ]
 DEV = [
     "bumpversion",
+    "pre-commit",
 ]
 DOCS = [
     "recommonmark",