Examples use GPUs if available.
This commit is contained in:
@@ -3,17 +3,7 @@
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
def __init__(self, param):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||
if torch.isnan(self.param).any():
|
||||
raise ValueError("NaN encountered. Stopping.")
|
||||
|
||||
from prototorch.models.callbacks import StopOnNaN
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
@@ -40,11 +30,12 @@ if __name__ == "__main__":
|
||||
noise=1e-1))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
|
||||
snan = StopOnNaN(model.proto_layer.components)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
gpus=-1,
|
||||
max_epochs=200,
|
||||
callbacks=[vis, snan],
|
||||
)
|
||||
|
Reference in New Issue
Block a user