LightningCLI Example.

Alexander Engelsberger 2021-05-21 17:10:36 +02:00
parent 8ce18f83ce
commit b60db3174a
3 changed files with 80 additions and 0 deletions

examples/cli/config.yaml Normal file

@@ -0,0 +1,9 @@
model:
  hparams:
    input_dim: 784
    latent_dim: 784
    distribution:
      num_classes: 10
      prototypes_per_class: 2
    proto_lr: 0.01
    bb_lr: 0.01
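
For orientation, the sketch below shows roughly what LightningCLI derives from the model: section above; the actual parsing goes through jsonargparse, so treat this as an illustration of the mapping, not the library's internals:

# Illustrative only: the nested dict that LightningCLI parses from the YAML
# above and passes to Model.__init__ as the "hparams" keyword argument.
hparams = {
    "input_dim": 784,
    "latent_dim": 784,
    "distribution": {
        "num_classes": 10,
        "prototypes_per_class": 2,
    },
    "proto_lr": 0.01,
    "bb_lr": 0.01,
}
# Equivalent manual construction (sketch): model = Model(hparams=hparams)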

examples/cli/gmlvq.py Normal file

@@ -0,0 +1,13 @@
"""GMLVQ example using the MNIST dataset."""
from prototorch.models import ImageGLVQ
from pytorch_lightning.utilities.cli import LightningCLI
from mnist import TrainOnMNIST
class Model(TrainOnMNIST, ImageGLVQ):
"""Model Definition"""
cli = LightningCLI(Model)
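
Assuming the three files sit together in examples/cli/, training can then be started from the shell. The --config flag comes from LightningCLI's jsonargparse integration, and trainer options can be overridden the same way; exact flags may differ across pytorch_lightning versions:

python gmlvq.py --config config.yaml
python gmlvq.py --config config.yaml --trainer.max_epochs 10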

examples/cli/mnist.py Normal file

@@ -0,0 +1,58 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
# When doing distributed training, Datamodules have two optional arguments for
# granular control over download/prepare/splitting data:
# OPTIONAL, called only on 1 GPU/machine
def prepare_data(self):
MNIST("~/datasets", train=True, download=True)
MNIST("~/datasets", train=False, download=True)
# OPTIONAL, called for every GPU/machine (assigning state is OK)
def setup(self, stage=None):
# transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])
# split dataset
if stage in (None, 'fit'):
mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(
mnist_train, [55000, 5000])
if stage == (None, 'test'):
self.mnist_test = MNIST("~/datasets",
train=False,
transform=transform)
# return the dataloader for each split
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
return mnist_test
class TrainOnMNIST(pl.LightningModule):
datamodule = MNISTDataModule(batch_size=250)
def prototype_initializer(self, **kwargs):
return pt.components.Zeros((28, 28, 1))
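
As a quick standalone check (not part of the commit), the datamodule can be exercised on its own; the expected shapes follow from torchvision's MNIST plus ToTensor:

# Sketch: exercise MNISTDataModule outside of the CLI.
dm = MNISTDataModule(batch_size=32)
dm.prepare_data()        # downloads MNIST into ~/datasets
dm.setup(stage='fit')    # builds the 55000/5000 train/val split
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])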