fix: examples/gmlvq_mnist.py

This commit is contained in:
Jensun Ravichandran 2021-06-21 14:42:28 +02:00
parent 612ee8dc6a
commit 72404f7c4e

View File

@ -2,13 +2,12 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -56,7 +55,7 @@ if __name__ == "__main__":
model = pt.models.ImageGMLVQ( model = pt.models.ImageGMLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SMI(train_ds), prototypes_initializer=pt.initializers.SMCI(train_ds),
) )
# Callbacks # Callbacks
@ -95,7 +94,7 @@ if __name__ == "__main__":
], ],
terminate_on_nan=True, terminate_on_nan=True,
weights_summary=None, weights_summary=None,
accelerator="ddp", # accelerator="ddp",
) )
# Training loop # Training loop