Add MNIST datamodule and training mixin factory.

This commit is contained in:
Alexander Engelsberger 2021-05-28 16:33:31 +02:00
parent b7edee02c3
commit dade502686
2 changed files with 20 additions and 17 deletions

View File

@@ -1,12 +1,11 @@
 """GLVQ example using the MNIST dataset."""
 from prototorch.models import ImageGLVQ
+from prototorch.models.data import train_on_mnist
 from pytorch_lightning.utilities.cli import LightningCLI
-from mnist import TrainOnMNIST


-class GLVQMNIST(TrainOnMNIST, ImageGLVQ):
+class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ):
     """Model Definition."""

View File

@@ -10,15 +10,12 @@ class MNISTDataModule(pl.LightningDataModule):
         super().__init__()
         self.batch_size = batch_size

-    # When doing distributed training, Datamodules have two optional arguments for
-    # granular control over download/prepare/splitting data:
-
-    # OPTIONAL, called only on 1 GPU/machine
+    # Download mnist dataset as side-effect, only called on the first cpu
     def prepare_data(self):
         MNIST("~/datasets", train=True, download=True)
         MNIST("~/datasets", train=False, download=True)

-    # OPTIONAL, called for every GPU/machine (assigning state is OK)
+    # called for every GPU/machine (assigning state is OK)
     def setup(self, stage=None):
         # Transforms
         transform = transforms.Compose([
@@ -28,13 +25,17 @@ class MNISTDataModule(pl.LightningDataModule):
         if stage in (None, "fit"):
             mnist_train = MNIST("~/datasets", train=True, transform=transform)
             self.mnist_train, self.mnist_val = random_split(
-                mnist_train, [55000, 5000])
+                mnist_train,
+                [55000, 5000],
+            )
         if stage == (None, "test"):
-            self.mnist_test = MNIST("~/datasets",
-                                    train=False,
-                                    transform=transform)
+            self.mnist_test = MNIST(
+                "~/datasets",
+                train=False,
+                transform=transform,
+            )

-    # Return the dataloader for each split
+    # Dataloaders
     def train_dataloader(self):
         mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
         return mnist_train
@@ -48,8 +49,11 @@ class MNISTDataModule(pl.LightningDataModule):
         return mnist_test


-class TrainOnMNIST(pl.LightningModule):
-    datamodule = MNISTDataModule(batch_size=256)
+def train_on_mnist(batch_size=256) -> type:
+    class DataClass(pl.LightningModule):
+        datamodule = MNISTDataModule(batch_size=batch_size)

-    def prototype_initializer(self, **kwargs):
-        return pt.components.Zeros((28, 28, 1))
+        def prototype_initializer(self, **kwargs):
+            return pt.components.Zeros((28, 28, 1))
+
+    return DataClass