Spaces:
Runtime error
Runtime error
File size: 8,805 Bytes
217780a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from https://github.com/huggingface/transformers/blob/f93c90d21749b61bd89152a7fe99a839df29ed94/src/transformers/debug_utils.py
"""
import json
from transformers.utils import ExplicitEnum, is_torch_available, logging
from m4.training.utils import get_stats
# `torch` is only imported when it is installed, but `ActivationTracker.analyse_variable` and
# `detect_overflow` below call `torch.*` directly, so both need torch to be available at call time.
if is_torch_available():
    import torch

logger = logging.get_logger(__name__)
class ActivationTracker:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` activation elements.

    This class plugs forward hooks into every submodule of the model and records activation statistics into a list
    of dictionaries: `jsonl_stats`. Recording is only active during training (see `is_train` / `is_eval`), and only
    while `trace_activation` is True (see `activate_hooks` / `deactivate_hooks`). Tracing is off by default.

    In practice, since this tracking requires additional computation, activations are meant to be tracked only
    every X steps. In the case of gradient accumulation, all the batches being accumulated are recorded and
    identified by the `batch_idx` key (see `fill_in_batch_idx`).

    Args:
        model (`nn.Module`):
            The model to debug.
        abort_after_batch_num (`int`, *optional*):
            If set, a `ValueError` is raised from the forward hook once more than this many batches have completed.
    """

    def __init__(
        self,
        model,
        abort_after_batch_num=None,
    ):
        self.model = model
        self.is_validation = False
        self.abort_after_batch_num = abort_after_batch_num
        self.jsonl_stats = []
        self.batch_number = 0
        self.detected_overflow = False
        # Tracing is off until `activate_hooks` is called. Without this default, `forward_hook` would raise
        # AttributeError on the very first forward pass if no activate/deactivate call happened before it.
        self.trace_activation = False
        self.analyse_model()
        self.register_forward_hook()

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx, current_module_stats):
        """
        Merge the stats of `var` (if it is a tensor) into `current_module_stats` and return that same dict.

        Non-tensor values are ignored. If `var` contains `nan`/`inf`, `detected_overflow` is set so that
        `forward_hook` aborts.
        """
        if torch.is_tensor(var):
            current_module_stats.update(get_stats(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        return current_module_stats

    def create_frame(self, module, input, output):
        """
        Collect the stats of `module`'s tensor inputs and outputs and append them as one entry of `jsonl_stats`.

        Each entry holds the module's qualified `name`, its class name as `type`, plus the per-tensor stats keyed
        by context strings such as `input[0]`, `output` or `output[1][0]`.
        """
        module_name = f"{self.module_names[module]}"
        module_type = f"{module.__class__.__name__}"
        current_module_stats = {}

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                current_module_stats = self.analyse_variable(x, f"input[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(input, "input", current_module_stats)

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples (only one level of nesting is inspected)
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        current_module_stats = self.analyse_variable(y, f"output[{i}][{j}]", current_module_stats)
                else:
                    current_module_stats = self.analyse_variable(x, f"output[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(output, "output", current_module_stats)

        if not current_module_stats:
            return

        # With gradient checkpointing, the forward hook is called twice for some (not all) modules, which would
        # produce repeated entries. Only entries of the batch currently being recorded (not yet tagged by
        # `fill_in_batch_idx`) count as duplicates, so that every batch of a gradient-accumulation cycle keeps
        # its own entries.
        # NOTE(review): this assumes `fill_in_batch_idx` is called after the batch's backward pass (where the
        # checkpointed recomputation happens) — confirm against the training loop.
        already_recorded = {(r["name"], r["type"]) for r in self.jsonl_stats if "batch_idx" not in r}
        if (module_name, module_type) in already_recorded:
            return

        self.jsonl_stats.append(
            {
                "name": module_name,
                "type": module_type,
                **current_module_stats,
            }
        )

    def register_forward_hook(self):
        """Attach `forward_hook` to the model and to every one of its submodules."""
        # NOTE: the returned hook handles are discarded, so these hooks cannot be removed later.
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        """
        Hook run after each module's forward: count batches, record stats when tracing, and abort if requested.

        Raises:
            ValueError: if a `nan`/`inf` has been detected, or once more than `abort_after_batch_num` batches
                have completed.
        """
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        # Count batch numbers: the root model's hook fires only after all of its submodules have run, i.e. when
        # the whole forward pass completes, so it marks the end of a batch.
        if module is self.model:
            self.batch_number += 1

        if self.trace_activation and not self.is_validation:
            self.create_frame(module, input, output)

        if self.detected_overflow:
            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )

    def fill_in_batch_idx(self, batch_idx):
        """
        Tag every entry recorded since the previous call with `batch_idx`.

        Raises:
            ValueError: if an already tagged entry carries a `batch_idx` larger than the given one.
        """
        for record in self.jsonl_stats:
            if "batch_idx" not in record:
                record["batch_idx"] = batch_idx
            elif record["batch_idx"] > batch_idx:
                raise ValueError("`batch_idx` should be increasing")

    def dump_stats(self, log_activations_filename, curr_opt_step):
        """
        Append all entries to `log_activations_filename`, one JSON object per line, each stamped with
        `step=curr_opt_step`.

        Entries are modified in place (the `step` key) and are not cleared; use `reset_jsonl_stats` for that.
        """
        with open(log_activations_filename, "a", encoding="utf-8") as file:
            for record in self.jsonl_stats:
                record["step"] = curr_opt_step
                file.write(json.dumps(record) + "\n")

    def reset_jsonl_stats(self):
        """Drop all recorded entries."""
        self.jsonl_stats = []

    def activate_hooks(self):
        """Start recording stats on subsequent forward passes (training mode only)."""
        self.trace_activation = True

    def deactivate_hooks(self):
        """Stop recording stats; the hooks still count batches and enforce the abort conditions."""
        self.trace_activation = False

    def is_eval(self):
        """Switch to validation mode: nothing is recorded even if tracing is active."""
        self.is_validation = True

    def is_train(self):
        """Switch back to training mode, where recording follows `trace_activation`."""
        self.is_validation = False
def detect_overflow(var, ctx):
    """
    Check a tensor for non-finite values and report any that are found.

    Best called right after the operation that produced `var`, so overflows/underflows are caught where they
    first appear. A `"<ctx> has nans"` and/or `"<ctx> has infs"` line is printed for each kind of problem found.
    A few extra diagnostics are available below behind flags that are disabled by default.

    Args:
        var: the tensor to inspect
        ctx: label used as the prefix of every printed message

    Return:
        `True` when `var` contains at least one `nan` or `inf`, `False` otherwise
    """
    found_nan = bool(torch.isnan(var).any().item())
    if found_nan:
        print(f"{ctx} has nans")

    found_inf = bool(torch.isinf(var).any().item())
    if found_inf:
        print(f"{ctx} has infs")

    # Optional diagnostics, all switched off; flip a flag to True while debugging.
    show_large_counts = False
    show_min_max = False
    show_summary = False

    if show_large_counts:
        magnitudes = var.abs()
        for threshold in (100, 1000, 10000):
            n_large = var[torch.ge(magnitudes, threshold)].numel()
            if n_large > 0:
                print(f"{ctx}: n{threshold}={n_large}")
    if show_min_max:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")
    if show_summary:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return found_nan or found_inf
class DebugOption(ExplicitEnum):
    """Debug mode names, stored as string-valued enum members."""

    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
|