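# NOTE(review): the tokens below appeared as bare text above the license header.
# They look like repository-page residue (author / commit message / commit id),
# not code -- confirm and drop if so: JustinLin610, update, 10b0761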
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
logger = logging.getLogger(__name__)
class NanDetector:
    """
    Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name

    Hooks are registered on every module yielded by ``model.named_modules()``
    when the detector is constructed and are removed again by :meth:`close`.
    The detector can be used as a context manager; on exit it additionally
    logs the gradient norms of all parameters if any norm is NaN/Inf.

    Args:
        model: object exposing ``named_parameters()`` and ``named_modules()``.
        forward: if true, register a forward hook on every module.
        backward: if true, register a backward hook on every module.
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        # Snapshot taken once so __exit__ can report per-parameter grad norms.
        self.named_parameters = list(model.named_parameters())
        self.reset()

        for name, mod in model.named_modules():
            # Written inside the class body, so Python mangles the attribute to
            # `_NanDetector__module_name`; `_apply` reads it under the same name.
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Log gradient norms if any of them is NaN/Inf, then remove all hooks.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        # Dump out all model gnorms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is None:
                continue
            grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
            norm[name] = grad_norm.item()
            if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                gradients[name] = param.grad.data
        if gradients:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            # Lazy %-args: the (potentially large) tensors are only formatted
            # when INFO logging is actually enabled.
            logger.info("norms: %s", norm)
            logger.info("gradients: %s", gradients)

        self.close()

    def add_hooks(self, module):
        """Register the enabled forward/backward hooks on ``module``."""
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            # NOTE(review): PyTorch documents `register_backward_hook` as
            # deprecated in favour of `register_full_backward_hook`; kept as-is
            # because the two differ in semantics -- confirm before switching.
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
        """Re-arm detection for both the forward and the backward pass."""
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
        """Return a message if ``tensor`` contains NaN/Inf, otherwise ``None``.

        Non-floating-point tensors and tensors with fewer than two elements
        are never reported. NaN takes precedence over Inf.
        """
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
        """Check ``x`` (a tensor, or a dict/list/tuple nesting of tensors).

        When a NaN/Inf is found, a warning is logged and the matching
        ``has_printed_*`` flag is set so later hook calls are skipped. For the
        forward pass the min/max of the first input is appended to the message.
        """
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                # max()/min() raise on a zero-element tensor; skip the extra
                # context rather than crash the pass being debugged.
                if torch.is_tensor(inp) and not backward and inp.numel() > 0:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )

                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, (list, tuple)):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        """Forward hook: inspect ``output`` unless a forward issue was already logged."""
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        """Backward hook: inspect ``output`` unless a backward issue was already logged."""
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        """Remove every registered hook; safe to call more than once."""
        for hook in self.fhooks + self.bhooks:
            hook.remove()
        # Drop the stale handles so a repeated close() is a no-op.
        self.fhooks.clear()
        self.bhooks.clear()