Remove normalization transform from cli example
This commit is contained in:
parent
db965541fd
commit
cc49f26b77
@ -1,4 +1,4 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
"""GLVQ example using the MNIST dataset."""
|
||||
|
||||
from prototorch.models import ImageGLVQ
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
@ -6,8 +6,8 @@ from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from mnist import TrainOnMNIST
|
||||
|
||||
|
||||
class Model(TrainOnMNIST, ImageGLVQ):
|
||||
"""Model Definition"""
|
||||
class GLVQMNIST(TrainOnMNIST, ImageGLVQ):
|
||||
"""Model Definition."""
|
||||
|
||||
|
||||
cli = LightningCLI(Model)
|
||||
cli = LightningCLI(GLVQMNIST)
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@ -22,22 +20,21 @@ class MNISTDataModule(pl.LightningDataModule):
|
||||
|
||||
# OPTIONAL, called for every GPU/machine (assigning state is OK)
|
||||
def setup(self, stage=None):
|
||||
# transforms
|
||||
# Transforms
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||
])
|
||||
# split dataset
|
||||
if stage in (None, 'fit'):
|
||||
# Split dataset
|
||||
if stage in (None, "fit"):
|
||||
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_train, [55000, 5000])
|
||||
if stage == (None, 'test'):
|
||||
if stage == (None, "test"):
|
||||
self.mnist_test = MNIST("~/datasets",
|
||||
train=False,
|
||||
transform=transform)
|
||||
|
||||
# return the dataloader for each split
|
||||
# Return the dataloader for each split
|
||||
def train_dataloader(self):
|
||||
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
||||
return mnist_train
|
||||
@ -52,7 +49,7 @@ class MNISTDataModule(pl.LightningDataModule):
|
||||
|
||||
|
||||
class TrainOnMNIST(pl.LightningModule):
|
||||
datamodule = MNISTDataModule(batch_size=250)
|
||||
datamodule = MNISTDataModule(batch_size=256)
|
||||
|
||||
def prototype_initializer(self, **kwargs):
|
||||
return pt.components.Zeros((28, 28, 1))
|
||||
|
Loading…
Reference in New Issue
Block a user