# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch


logger = logging.getLogger(__name__)


class NanDetector:
    """
    Detects the first NaN or Inf in the forward and/or backward pass and logs it,
    together with the name of the module that produced it.
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()

        for name, mod in model.named_modules():
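            # stash the module's qualified name on the module itself so the
            # forward/backward hooks can report which module produced a NaN/Inf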
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Dump out all model gnorms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data

        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")

        self.close()

    def add_hooks(self, module):
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
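            # NOTE: register_backward_hook is deprecated in newer PyTorch
            # releases in favor of register_full_backward_hook; it is kept
            # here to preserve the behavior of the original module.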
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
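        # the has_printed_* flags ensure at most one warning is logged per
        # direction (forward/backward), so a single bad batch does not flood
        # the log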
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
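        # returns a diagnostic message if `tensor` contains NaN/Inf values,
        # otherwise None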
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
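        # recursively walk the hook output (tensor, dict, list or tuple) and
        # log a warning for the first NaN/Inf value found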
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )
                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, list) or isinstance(x, tuple):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        for hook in self.fhooks + self.bhooks:
            hook.remove()
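

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original fairseq module). It
# shows how NanDetector is typically wrapped around a forward/backward pass;
# the toy model, tensor shapes, and the injected Inf below are arbitrary
# choices made purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    toy_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

    with NanDetector(toy_model, forward=True, backward=True):
        x = torch.randn(4, 8)
        x[0, 0] = float("inf")  # inject an Inf so the hooks have something to report
        loss = toy_model(x).sum()
        loss.backward()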