test(githooks): Add githooks for automatic commit checks
This commit is contained in:
@@ -5,4 +5,4 @@ Examples in this folder use the experimental [Lightning CLI](https://pytorch-lig
|
||||
To use the example run
|
||||
```
|
||||
python gmlvq.py --config gmlvq.yaml
|
||||
```
|
||||
```
|
||||
|
@@ -1,12 +1,20 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.data import train_on_mnist
|
||||
import torch
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
|
||||
"""Model Definition."""
|
||||
import prototorch as pt
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.abstract import PrototypeModel
|
||||
from prototorch.models.data import MNISTDataModule
|
||||
|
||||
|
||||
cli = LightningCLI(GMLVQMNIST)
|
||||
class ExperimentClass(ImageGMLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.zeros(28 * 28),
|
||||
**kwargs)
|
||||
|
||||
|
||||
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
|
||||
|
@@ -1,9 +1,11 @@
|
||||
model:
|
||||
hparams:
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
distribution:
|
||||
num_classes: 10
|
||||
prototypes_per_class: 2
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
proto_lr: 0.01
|
||||
bb_lr: 0.01
|
||||
data:
|
||||
batch_size: 32
|
||||
|
Reference in New Issue
Block a user