test(githooks): Add githooks for automatic commit checks
This commit is contained in:
parent
c87ed5ba8b
commit
8956ee75ad
54
.pre-commit-config.yaml
Normal file
54
.pre-commit-config.yaml
Normal file
@ -0,0 +1,54 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.0.1
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-ast
|
||||
- id: check-case-conflict
|
||||
|
||||
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
hooks:
|
||||
- id: autoflake
|
||||
|
||||
- repo: http://github.com/PyCQA/isort
|
||||
rev: 5.8.0
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: 'v0.902'
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: prototorch
|
||||
additional_dependencies: [types-pkg_resources]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
||||
hooks:
|
||||
- id: yapf
|
||||
|
||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||
rev: v1.9.0 # Use the ref you want to point at
|
||||
hooks:
|
||||
- id: python-use-type-annotations
|
||||
- id: python-no-log-warn
|
||||
- id: python-check-blanket-noqa
|
||||
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.19.4
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/jorisroovers/gitlint
|
||||
rev: "v0.15.1"
|
||||
hooks:
|
||||
- id: gitlint
|
||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
13
README.md
13
README.md
@ -18,6 +18,19 @@ pip install prototorch_models
|
||||
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
|
||||
be available for use in your Python environment as `prototorch.models`.
|
||||
|
||||
## Contribution
|
||||
|
||||
This repository contains definition for [git hooks](https://githooks.com).
|
||||
[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
|
||||
Please install the hooks by running
|
||||
```bash
|
||||
pre-commit install
|
||||
pre-commit install --hook-type commit-msg
|
||||
```
|
||||
before creating the first commit.
|
||||
|
||||
|
||||
|
||||
## Available models
|
||||
|
||||
### LVQ Family
|
||||
|
@ -70,4 +70,3 @@
|
||||
year="2018",
|
||||
publisher="Springer International Publishing",
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,20 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.data import train_on_mnist
|
||||
import torch
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
|
||||
"""Model Definition."""
|
||||
import prototorch as pt
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.abstract import PrototypeModel
|
||||
from prototorch.models.data import MNISTDataModule
|
||||
|
||||
|
||||
cli = LightningCLI(GMLVQMNIST)
|
||||
class ExperimentClass(ImageGMLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.zeros(28 * 28),
|
||||
**kwargs)
|
||||
|
||||
|
||||
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
|
||||
|
@ -1,9 +1,11 @@
|
||||
model:
|
||||
hparams:
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
distribution:
|
||||
num_classes: 10
|
||||
prototypes_per_class: 2
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
proto_lr: 0.01
|
||||
bb_lr: 0.01
|
||||
data:
|
||||
batch_size: 32
|
||||
|
@ -2,10 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,10 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,10 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def hex_to_rgb(hex_values):
|
||||
for v in hex_values:
|
||||
|
@ -2,10 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -3,10 +3,11 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def plot_matrix(matrix):
|
||||
title = "Lambda matrix"
|
||||
|
@ -2,13 +2,14 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -2,11 +2,12 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision.transforms import Lambda
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -1,10 +1,19 @@
|
||||
import prototorch as pt
|
||||
"""Prototorch Data Modules
|
||||
|
||||
This allows to store the used dataset inside a Lightning Module.
|
||||
Mainly used for PytorchLightningCLI configurations.
|
||||
"""
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
# MNIST
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
|
||||
return mnist_test
|
||||
|
||||
|
||||
def train_on_mnist(batch_size=256) -> type:
|
||||
class DataClass(pl.LightningModule):
|
||||
datamodule = MNISTDataModule(batch_size=batch_size)
|
||||
# def train_on_mnist(batch_size=256) -> type:
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = MNISTDataModule(batch_size=batch_size)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
prototype_initializer = kwargs.pop(
|
||||
"prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||
super().__init__(*args,
|
||||
prototype_initializer=prototype_initializer,
|
||||
**kwargs)
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
return DataClass
|
||||
# dc: Type[DataClass] = DataClass
|
||||
# return dc
|
||||
|
||||
|
||||
# ABSTRACT
|
||||
class GeneralDataModule(pl.LightningDataModule):
|
||||
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
|
||||
super().__init__()
|
||||
self.train_dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return DataLoader(self.train_dataset, batch_size=self.batch_size)
|
||||
|
||||
|
||||
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = GeneralDataModule(dataset, batch_size)
|
||||
# datashape = dataset[0][0].shape
|
||||
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
|
||||
|
||||
# def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer",
|
||||
# pt.components.Zeros(self.datashape),
|
||||
# )
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
# return DataClass
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# from prototorch.models import GLVQ
|
||||
|
||||
# demo_dataset = pt.datasets.Iris()
|
||||
|
||||
# TrainingClass: Type = train_on_dataset(demo_dataset)
|
||||
|
||||
# class DemoGLVQ(TrainingClass, GLVQ):
|
||||
# """Model Definition."""
|
||||
|
||||
# # Hyperparameters
|
||||
# hparams = dict(
|
||||
# distribution={
|
||||
# "num_classes": 3,
|
||||
# "prototypes_per_class": 4
|
||||
# },
|
||||
# lr=0.01,
|
||||
# )
|
||||
|
||||
# initialized = DemoGLVQ(hparams)
|
||||
# print(initialized)
|
||||
|
@ -3,11 +3,8 @@
|
||||
import torch
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.distances import (lomega_distance, omega_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
|
15
setup.py
15
setup.py
@ -1,10 +1,12 @@
|
||||
"""
|
||||
_____ _ _______ _
|
||||
| __ \ | | |__ __| | |
|
||||
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
|
||||
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
|
||||
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|
||||
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|Plugin
|
||||
|
||||
######
|
||||
# # ##### #### ##### #### ##### #### ##### #### # #
|
||||
# # # # # # # # # # # # # # # # # #
|
||||
###### # # # # # # # # # # # # # ######
|
||||
# ##### # # # # # # # # ##### # # #
|
||||
# # # # # # # # # # # # # # # # #
|
||||
# # # #### # #### # #### # # #### # #Plugin
|
||||
|
||||
ProtoTorch models Plugin Package
|
||||
"""
|
||||
@ -29,6 +31,7 @@ CLI = [
|
||||
]
|
||||
DEV = [
|
||||
"bumpversion",
|
||||
"pre-commit",
|
||||
]
|
||||
DOCS = [
|
||||
"recommonmark",
|
||||
|
Loading…
Reference in New Issue
Block a user