Remove normalization transform from cli example

This commit is contained in:
Jensun Ravichandran 2021-05-25 21:13:37 +02:00
parent db965541fd
commit cc49f26b77
2 changed files with 10 additions and 13 deletions

View File

@ -1,4 +1,4 @@
"""GMLVQ example using the MNIST dataset.""" """GLVQ example using the MNIST dataset."""
from prototorch.models import ImageGLVQ from prototorch.models import ImageGLVQ
from pytorch_lightning.utilities.cli import LightningCLI from pytorch_lightning.utilities.cli import LightningCLI
@ -6,8 +6,8 @@ from pytorch_lightning.utilities.cli import LightningCLI
from mnist import TrainOnMNIST from mnist import TrainOnMNIST
class Model(TrainOnMNIST, ImageGLVQ): class GLVQMNIST(TrainOnMNIST, ImageGLVQ):
"""Model Definition""" """Model Definition."""
cli = LightningCLI(Model) cli = LightningCLI(GLVQMNIST)

View File

@ -1,5 +1,3 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
@ -22,22 +20,21 @@ class MNISTDataModule(pl.LightningDataModule):
# OPTIONAL, called for every GPU/machine (assigning state is OK) # OPTIONAL, called for every GPU/machine (assigning state is OK)
def setup(self, stage=None): def setup(self, stage=None):
# transforms # Transforms
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]) ])
# split dataset # Split dataset
if stage in (None, 'fit'): if stage in (None, "fit"):
mnist_train = MNIST("~/datasets", train=True, transform=transform) mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split( self.mnist_train, self.mnist_val = random_split(
mnist_train, [55000, 5000]) mnist_train, [55000, 5000])
if stage == (None, 'test'): if stage == (None, "test"):
self.mnist_test = MNIST("~/datasets", self.mnist_test = MNIST("~/datasets",
train=False, train=False,
transform=transform) transform=transform)
# return the dataloader for each split # Return the dataloader for each split
def train_dataloader(self): def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size) mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train return mnist_train
@ -52,7 +49,7 @@ class MNISTDataModule(pl.LightningDataModule):
class TrainOnMNIST(pl.LightningModule): class TrainOnMNIST(pl.LightningModule):
datamodule = MNISTDataModule(batch_size=250) datamodule = MNISTDataModule(batch_size=256)
def prototype_initializer(self, **kwargs): def prototype_initializer(self, **kwargs):
return pt.components.Zeros((28, 28, 1)) return pt.components.Zeros((28, 28, 1))