prototorch_models/prototorch/models/callbacks.py

15 lines
366 B
Python
Raw Normal View History

2021-05-13 13:22:01 +00:00
"""Callbacks for PyTorch Lightning modules."""
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
    """Abort training as soon as a monitored tensor contains a NaN.

    Args:
        param: The tensor to monitor. It is stored by reference (not
            copied), so in-place updates to it are visible to the check.
            Presumably a model parameter such as a prototype tensor --
            confirm at the call site.

    Raises:
        ValueError: From the epoch-end hooks, when any element of
            ``param`` is NaN.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def _check_for_nan(self):
        """Raise ``ValueError`` if any element of ``self.param`` is NaN."""
        if torch.isnan(self.param).any():
            raise ValueError("NaN encountered. Stopping.")

    def on_epoch_end(self, trainer, pl_module, logs=None):
        """Check for NaNs at the end of an epoch (older Lightning hook).

        ``logs`` is unused and kept only so existing callers do not break.
        It defaults to ``None`` instead of a shared mutable ``{}``.
        """
        self._check_for_nan()

    def on_train_epoch_end(self, trainer, pl_module):
        """Check for NaNs at the end of each training epoch.

        Newer Lightning releases no longer invoke ``on_epoch_end``, so this
        hook keeps the guard effective there. On older releases both hooks
        may run; the check is read-only, so running it twice is harmless.
        """
        self._check_for_nan()