Examples use GPUs if available.

This commit is contained in:
Alexander Engelsberger
2021-05-13 15:22:01 +02:00
parent 8f9c29bd2b
commit 0eac2ce326
14 changed files with 56 additions and 39 deletions

View File

@@ -3,17 +3,7 @@
import prototorch as pt
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")
from prototorch.models.callbacks import StopOnNaN
if __name__ == "__main__":
# Dataset
@@ -40,11 +30,12 @@ if __name__ == "__main__":
noise=1e-1))
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
snan = StopOnNaN(model.proto_layer.components)
# Setup trainer
trainer = pl.Trainer(
gpus=-1,
max_epochs=200,
callbacks=[vis, snan],
)