fix: examples/gmlvq_mnist.py
This commit is contained in:
parent
612ee8dc6a
commit
72404f7c4e
@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -56,7 +55,7 @@ if __name__ == "__main__":
|
|||||||
model = pt.models.ImageGMLVQ(
|
model = pt.models.ImageGMLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
@ -95,7 +94,7 @@ if __name__ == "__main__":
|
|||||||
],
|
],
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
accelerator="ddp",
|
# accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
Loading…
Reference in New Issue
Block a user