FLARE / flare /utils /debug.py
yzhouchen001's picture
cleaned up
2c0063e
import torch
def _iter_tensors(obj):
    """Yield every ``torch.Tensor`` found in *obj*, descending into containers.

    Tuples, lists and dict values are searched recursively; any other
    non-tensor object is ignored.
    """
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, (tuple, list)):
        for item in obj:
            yield from _iter_tensors(item)
    elif isinstance(obj, dict):
        for item in obj.values():
            yield from _iter_tensors(item)


def nan_hook(self, inp, output):
    """Forward-hook style check that raises if *output* contains NaN or Inf.

    The ``(self, inp, output)`` signature matches what
    ``torch.nn.Module.register_forward_hook`` passes to a hook, so this can be
    attached to a module to locate where non-finite values first appear.

    Args:
        self: The object whose output is being checked; only its class name
            is used, to label the error message.
        inp: The inputs that produced *output*. Unused.
        output: A tensor, or a (possibly nested) tuple/list/dict containing
            tensors. Non-tensor entries are skipped.

    Returns:
        None. The hook never replaces the output.

    Raises:
        RuntimeError: If any tensor in *output* contains a NaN or an Inf.
            The message names the class of *self* and lists the offending
            indices.
    """
    module_name = self.__class__.__name__
    checks = (("NAN", torch.isnan), ("INF", torch.isinf))
    for tensor in _iter_tensors(output):
        for label, predicate in checks:
            mask = predicate(tensor)
            if mask.any():
                # Build one formatted message: passing the index tensor as a
                # second exception argument would make the error print as a
                # tuple instead of readable text.
                raise RuntimeError(
                    f"In {module_name}: Found {label} in output at indices: "
                    f"{mask.nonzero()}"
                )