Add MNIST datamodule and training mixin factory.
This commit is contained in:
parent
b7edee02c3
commit
dade502686
@ -1,12 +1,11 @@
|
|||||||
"""GLVQ example using the MNIST dataset."""
|
"""GLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
from prototorch.models import ImageGLVQ
|
from prototorch.models import ImageGLVQ
|
||||||
|
from prototorch.models.data import train_on_mnist
|
||||||
from pytorch_lightning.utilities.cli import LightningCLI
|
from pytorch_lightning.utilities.cli import LightningCLI
|
||||||
|
|
||||||
from mnist import TrainOnMNIST
|
|
||||||
|
|
||||||
|
class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ):
|
||||||
class GLVQMNIST(TrainOnMNIST, ImageGLVQ):
|
|
||||||
"""Model Definition."""
|
"""Model Definition."""
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,15 +10,12 @@ class MNISTDataModule(pl.LightningDataModule):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
# When doing distributed training, Datamodules have two optional arguments for
|
# Download mnist dataset as side-effect, only called on the first cpu
|
||||||
# granular control over download/prepare/splitting data:
|
|
||||||
|
|
||||||
# OPTIONAL, called only on 1 GPU/machine
|
|
||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
MNIST("~/datasets", train=True, download=True)
|
MNIST("~/datasets", train=True, download=True)
|
||||||
MNIST("~/datasets", train=False, download=True)
|
MNIST("~/datasets", train=False, download=True)
|
||||||
|
|
||||||
# OPTIONAL, called for every GPU/machine (assigning state is OK)
|
# called for every GPU/machine (assigning state is OK)
|
||||||
def setup(self, stage=None):
|
def setup(self, stage=None):
|
||||||
# Transforms
|
# Transforms
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
@ -28,13 +25,17 @@ class MNISTDataModule(pl.LightningDataModule):
|
|||||||
if stage in (None, "fit"):
|
if stage in (None, "fit"):
|
||||||
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
||||||
self.mnist_train, self.mnist_val = random_split(
|
self.mnist_train, self.mnist_val = random_split(
|
||||||
mnist_train, [55000, 5000])
|
mnist_train,
|
||||||
|
[55000, 5000],
|
||||||
|
)
|
||||||
if stage == (None, "test"):
|
if stage == (None, "test"):
|
||||||
self.mnist_test = MNIST("~/datasets",
|
self.mnist_test = MNIST(
|
||||||
train=False,
|
"~/datasets",
|
||||||
transform=transform)
|
train=False,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
# Return the dataloader for each split
|
# Dataloaders
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||||
return mnist_train
|
return mnist_train
|
||||||
@ -48,8 +49,11 @@ class MNISTDataModule(pl.LightningDataModule):
|
|||||||
return mnist_test
|
return mnist_test
|
||||||
|
|
||||||
|
|
||||||
class TrainOnMNIST(pl.LightningModule):
|
def train_on_mnist(batch_size=256) -> type:
|
||||||
datamodule = MNISTDataModule(batch_size=256)
|
class DataClass(pl.LightningModule):
|
||||||
|
datamodule = MNISTDataModule(batch_size=batch_size)
|
||||||
|
|
||||||
def prototype_initializer(self, **kwargs):
|
def prototype_initializer(self, **kwargs):
|
||||||
return pt.components.Zeros((28, 28, 1))
|
return pt.components.Zeros((28, 28, 1))
|
||||||
|
|
||||||
|
return DataClass
|
Loading…
Reference in New Issue
Block a user