test(githooks): Add githooks for automatic commit checks

This commit is contained in:
Alexander Engelsberger
2021-06-16 16:16:34 +02:00
parent c87ed5ba8b
commit 8956ee75ad
24 changed files with 196 additions and 49 deletions

View File

@@ -5,4 +5,4 @@ Examples in this folder use the experimental [Lightning CLI](https://pytorch-lig
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```
```

View File

@@ -1,12 +1,20 @@
"""GMLVQ example using the MNIST dataset."""
from prototorch.models import ImageGMLVQ
from prototorch.models.data import train_on_mnist
import torch
from pytorch_lightning.utilities.cli import LightningCLI
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
"""Model Definition."""
import prototorch as pt
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
cli = LightningCLI(GMLVQMNIST)
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@@ -1,9 +1,11 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
input_dim: 784
latent_dim: 784
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
import prototorch as pt
def hex_to_rgb(hex_values):
for v in hex_values:

View File

@@ -2,10 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -3,10 +3,11 @@
import argparse
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
def plot_matrix(matrix):
title = "Lambda matrix"

View File

@@ -2,13 +2,14 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -2,11 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision.transforms import Lambda
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()