diff --git a/examples/cli/gmlvq.py b/examples/cli/gmlvq.py index 104fa31..d0c7e31 100644 --- a/examples/cli/gmlvq.py +++ b/examples/cli/gmlvq.py @@ -1,4 +1,4 @@ -"""GMLVQ example using the MNIST dataset.""" +"""GLVQ example using the MNIST dataset.""" from prototorch.models import ImageGLVQ from pytorch_lightning.utilities.cli import LightningCLI @@ -6,8 +6,8 @@ from pytorch_lightning.utilities.cli import LightningCLI from mnist import TrainOnMNIST -class Model(TrainOnMNIST, ImageGLVQ): - """Model Definition""" +class GLVQMNIST(TrainOnMNIST, ImageGLVQ): + """Model Definition.""" -cli = LightningCLI(Model) +cli = LightningCLI(GLVQMNIST) diff --git a/examples/cli/mnist.py b/examples/cli/mnist.py index 7786fc0..307c3b5 100644 --- a/examples/cli/mnist.py +++ b/examples/cli/mnist.py @@ -1,5 +1,3 @@ -"""GMLVQ example using the MNIST dataset.""" - import prototorch as pt import pytorch_lightning as pl from torch.utils.data import DataLoader, random_split @@ -22,22 +20,21 @@ class MNISTDataModule(pl.LightningDataModule): # OPTIONAL, called for every GPU/machine (assigning state is OK) def setup(self, stage=None): - # transforms + # Transforms transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) ]) - # split dataset - if stage in (None, 'fit'): + # Split dataset + if stage in (None, "fit"): mnist_train = MNIST("~/datasets", train=True, transform=transform) self.mnist_train, self.mnist_val = random_split( mnist_train, [55000, 5000]) - if stage == (None, 'test'): + if stage == (None, "test"): self.mnist_test = MNIST("~/datasets", train=False, transform=transform) - # return the dataloader for each split + # Return the dataloader for each split def train_dataloader(self): mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size) return mnist_train @@ -52,7 +49,7 @@ class MNISTDataModule(pl.LightningDataModule): class TrainOnMNIST(pl.LightningModule): - datamodule = MNISTDataModule(batch_size=250) + datamodule = MNISTDataModule(batch_size=256) def prototype_initializer(self, **kwargs): return pt.components.Zeros((28, 28, 1))