test(githooks): Add githooks for automatic commit checks
commit 8956ee75ad
parent c87ed5ba8b
.pre-commit-config.yaml (new file, 54 lines)
@@ -0,0 +1,54 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.0.1
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+      - id: check-ast
+      - id: check-case-conflict
+
+
+  - repo: https://github.com/myint/autoflake
+    rev: v1.4
+    hooks:
+      - id: autoflake
+
+  - repo: http://github.com/PyCQA/isort
+    rev: 5.8.0
+    hooks:
+      - id: isort
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v0.902'
+    hooks:
+      - id: mypy
+        files: prototorch
+        additional_dependencies: [types-pkg_resources]
+
+  - repo: https://github.com/pre-commit/mirrors-yapf
+    rev: 'v0.31.0'  # Use the sha / tag you want to point at
+    hooks:
+      - id: yapf
+
+  - repo: https://github.com/pre-commit/pygrep-hooks
+    rev: v1.9.0  # Use the ref you want to point at
+    hooks:
+      - id: python-use-type-annotations
+      - id: python-no-log-warn
+      - id: python-check-blanket-noqa
+
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v2.19.4
+    hooks:
+      - id: pyupgrade
+
+  - repo: https://github.com/jorisroovers/gitlint
+    rev: "v0.15.1"
+    hooks:
+      - id: gitlint
+        args: [--contrib=CT1, --ignore=B6, --msg-filename]
README.md (13 lines)
@@ -18,6 +18,19 @@ pip install prototorch_models
 of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 be available for use in your Python environment as `prototorch.models`.
 
+## Contribution
+
+This repository contains definitions for [git hooks](https://githooks.com).
+[Pre-commit](https://pre-commit.com) is installed as a development dependency with prototorch.
+Please install the hooks by running
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+before creating the first commit.
+
+
+
 ## Available models
 
 ### LVQ Family
@@ -70,4 +70,3 @@
 year="2018",
 publisher="Springer International Publishing",
 }
-
@@ -1,12 +1,20 @@
 """GMLVQ example using the MNIST dataset."""
 
-from prototorch.models import ImageGMLVQ
-from prototorch.models.data import train_on_mnist
+import torch
 from pytorch_lightning.utilities.cli import LightningCLI
 
+import prototorch as pt
+from prototorch.models import ImageGMLVQ
+from prototorch.models.abstract import PrototypeModel
+from prototorch.models.data import MNISTDataModule
+
 
-class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
-    """Model Definition."""
+class ExperimentClass(ImageGMLVQ):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams,
+                         optimizer=torch.optim.Adam,
+                         prototype_initializer=pt.components.zeros(28 * 28),
+                         **kwargs)
 
 
-cli = LightningCLI(GMLVQMNIST)
+cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
@@ -1,9 +1,11 @@
 model:
   hparams:
+    input_dim: 784
+    latent_dim: 784
     distribution:
       num_classes: 10
       prototypes_per_class: 2
-    input_dim: 784
-    latent_dim: 784
     proto_lr: 0.01
     bb_lr: 0.01
+data:
+  batch_size: 32
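The refactored example delegates training to LightningCLI, which constructs ImageGMLVQ from the `model:` block of the YAML config above and MNISTDataModule from its `data:` block. For readers unfamiliar with that mechanism, a rough hand-wired equivalent is sketched below; the hyperparameter values are copied from the YAML config, while the Trainer settings are illustrative assumptions and not part of this commit.

```python
"""Hand-wired sketch of what the LightningCLI example does (see assumptions above)."""
import pytorch_lightning as pl

from prototorch.models import ImageGMLVQ
from prototorch.models.data import MNISTDataModule

if __name__ == "__main__":
    # Hyperparameters mirror the model.hparams block of the YAML config.
    hparams = dict(
        input_dim=784,
        latent_dim=784,
        distribution=dict(num_classes=10, prototypes_per_class=2),
        proto_lr=0.01,
        bb_lr=0.01,
    )

    model = ImageGMLVQ(hparams)  # hparams dict passed positionally, as in the diff above
    datamodule = MNISTDataModule(batch_size=32)  # batch size from the data: block

    # Illustrative trainer settings; LightningCLI would normally take these
    # from the command line or the config file instead.
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, datamodule=datamodule)
```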
@@ -2,10 +2,11 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
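This hunk and the similar example hunks that follow are all produced by the same mechanism: the isort hook from the new pre-commit config regroups imports, apparently treating `prototorch` as a first-party package that gets its own block after the third-party imports. A generic sketch of the resulting layout (module names are illustrative; the exact isort settings are not shown in this diff):

```python
# Standard library imports
import argparse

# Third-party packages, one block, sorted alphabetically
import pytorch_lightning as pl
import torch

# First-party / local package in its own block, separated by a blank line
import prototorch as pt
```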
@@ -2,11 +2,12 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from torch.optim.lr_scheduler import ExponentialLR
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -2,10 +2,11 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -2,10 +2,11 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -2,11 +2,12 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -2,11 +2,12 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
 
+import prototorch as pt
+
 
 def hex_to_rgb(hex_values):
     for v in hex_values:
@@ -2,10 +2,11 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -3,10 +3,11 @@
 import argparse
 
 import matplotlib.pyplot as plt
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 
 def plot_matrix(matrix):
     title = "Lambda matrix"
@@ -2,13 +2,14 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
 from torch.optim.lr_scheduler import ExponentialLR
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -2,11 +2,12 @@
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from torchvision.transforms import Lambda
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -1,10 +1,19 @@
-import prototorch as pt
+"""Prototorch Data Modules
+
+This allows to store the used dataset inside a Lightning Module.
+Mainly used for PytorchLightningCLI configurations.
+"""
+from typing import Any, Optional, Type
 
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms
 from torchvision.datasets import MNIST
 
+import prototorch as pt
+
+
+# MNIST
 class MNISTDataModule(pl.LightningDataModule):
     def __init__(self, batch_size=32):
         super().__init__()
@@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
         return mnist_test
 
 
-def train_on_mnist(batch_size=256) -> type:
-    class DataClass(pl.LightningModule):
-        datamodule = MNISTDataModule(batch_size=batch_size)
+# def train_on_mnist(batch_size=256) -> type:
+#     class DataClass(pl.LightningModule):
+#         datamodule = MNISTDataModule(batch_size=batch_size)
 
-        def __init__(self, *args, **kwargs):
-            prototype_initializer = kwargs.pop(
-                "prototype_initializer", pt.components.Zeros((28, 28, 1)))
-            super().__init__(*args,
-                             prototype_initializer=prototype_initializer,
-                             **kwargs)
+#         def __init__(self, *args, **kwargs):
+#             prototype_initializer = kwargs.pop(
+#                 "prototype_initializer", pt.components.Zeros((28, 28, 1)))
+#             super().__init__(*args,
+#                              prototype_initializer=prototype_initializer,
+#                              **kwargs)
 
-    return DataClass
+#     dc: Type[DataClass] = DataClass
+#     return dc
+
+
+# ABSTRACT
+class GeneralDataModule(pl.LightningDataModule):
+    def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
+        super().__init__()
+        self.train_dataset = dataset
+        self.batch_size = batch_size
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_dataset, batch_size=self.batch_size)
+
+
+# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
+#     class DataClass(pl.LightningModule):
+#         datamodule = GeneralDataModule(dataset, batch_size)
+#         datashape = dataset[0][0].shape
+#         example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
+
+#         def __init__(self, *args: Any, **kwargs: Any) -> None:
+#             prototype_initializer = kwargs.pop(
+#                 "prototype_initializer",
+#                 pt.components.Zeros(self.datashape),
+#             )
+#             super().__init__(*args,
+#                              prototype_initializer=prototype_initializer,
+#                              **kwargs)
+
+#     return DataClass
+
+# if __name__ == "__main__":
+#     from prototorch.models import GLVQ
+
+#     demo_dataset = pt.datasets.Iris()
+
+#     TrainingClass: Type = train_on_dataset(demo_dataset)
+
+#     class DemoGLVQ(TrainingClass, GLVQ):
+#         """Model Definition."""
+
+#         # Hyperparameters
+#         hparams = dict(
+#             distribution={
+#                 "num_classes": 3,
+#                 "prototypes_per_class": 4
+#             },
+#             lr=0.01,
+#         )
+
+#     initialized = DemoGLVQ(hparams)
+#     print(initialized)
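The new GeneralDataModule simply wraps an arbitrary torch Dataset so it can be handed to a Lightning Trainer or to LightningCLI. A minimal usage sketch follows; only GeneralDataModule itself comes from this commit, while the Iris dataset, the GLVQ model, its hyperparameters, and the Zeros initializer are borrowed from the commented-out demo code above and are assumptions about the surrounding prototorch API.

```python
"""Usage sketch for GeneralDataModule (see the assumptions noted above)."""
import pytorch_lightning as pl

import prototorch as pt
from prototorch.models import GLVQ
from prototorch.models.data import GeneralDataModule

if __name__ == "__main__":
    # Any torch.utils.data.Dataset can be wrapped; Iris is only an example.
    dataset = pt.datasets.Iris()
    datamodule = GeneralDataModule(dataset, batch_size=64)

    # Hyperparameters copied from the commented-out DemoGLVQ block above.
    hparams = dict(
        distribution={
            "num_classes": 3,
            "prototypes_per_class": 4,
        },
        lr=0.01,
    )
    # The Zeros initializer mirrors the commented-out code; the shape (4 Iris
    # features) and whether GLVQ requires this argument are assumptions.
    model = GLVQ(hparams, prototype_initializer=pt.components.Zeros(4))

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=datamodule)
```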
@@ -3,11 +3,8 @@
 import torch
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
-from prototorch.functions.distances import (
-    lomega_distance,
-    omega_distance,
-    squared_euclidean_distance,
-)
+from prototorch.functions.distances import (lomega_distance, omega_distance,
+                                            squared_euclidean_distance)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer, LossLayer
setup.py (15 lines)
@@ -1,10 +1,12 @@
 """
-[ASCII-art "ProtoTorch" banner, 6 lines, ending in "Plugin"]
+[ASCII-art block-letter "Prototorch" banner, 7 lines, ending in "Plugin"]
 
 ProtoTorch models Plugin Package
 """
@@ -29,6 +31,7 @@ CLI = [
 ]
 DEV = [
     "bumpversion",
+    "pre-commit",
 ]
 DOCS = [
     "recommonmark",