Examples use GPUs if available.
This commit is contained in:
14
prototorch/models/callbacks.py
Normal file
14
prototorch/models/callbacks.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Callbacks for Pytorch Lighning Modules"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
def __init__(self, param):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||
if torch.isnan(self.param).any():
|
||||
raise ValueError("NaN encountered. Stopping.")
|
Reference in New Issue
Block a user