fix: examples/gmlvq_mnist.py

This commit is contained in:
Jensun Ravichandran 2021-06-21 14:42:28 +02:00
parent 612ee8dc6a
commit 72404f7c4e

View File

@ -2,13 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
@ -56,7 +55,7 @@ if __name__ == "__main__":
model = pt.models.ImageGMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SMI(train_ds),
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Callbacks
@ -95,7 +94,7 @@ if __name__ == "__main__":
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
# accelerator="ddp",
)
# Training loop