..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Debugging
=======================================================================================================================

Underflow and Overflow Detection
-----------------------------------------------------------------------------------------------------------------------

.. note::

    This feature is currently available for PyTorch only.

.. note::

    For multi-GPU training it requires DDP (``torch.distributed.launch``).

.. note::

    This feature can be used with any ``nn.Module``-based model.

If you start getting ``loss=NaN``, or the model exhibits some other abnormal behavior due to ``inf`` or ``nan`` in
activations or weights, you need to discover where the first underflow or overflow happens and what led to it. Luckily
you can accomplish that easily by activating a special module that will do the detection automatically.

If you're using :class:`~transformers.Trainer`, you just need to add:

.. code-block:: bash

    --debug underflow_overflow

to the normal command line arguments, or pass ``debug="underflow_overflow"`` when creating the
:class:`~transformers.TrainingArguments` object.

If you're using your own training loop or another Trainer you can accomplish the same with:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    # Registers forward hooks on every submodule of ``model``
    debug_overflow = DebugUnderflowOverflow(model)

:class:`~transformers.debug_utils.DebugUnderflowOverflow` inserts hooks into the model that, immediately after each
forward call, test the input and output variables and also the corresponding module's weights. As soon as ``inf`` or
``nan`` is detected in at least one element of the activations or weights, the program will assert and print a report
like this (this was caught with ``google/mt5-small`` under fp16 mixed precision):

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
                      encoder.block.1.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 2.57e+02 input[0]
    0.00e+00 2.85e+02 output
    [...]
                      encoder.block.2.layer.0 T5LayerSelfAttention
    6.78e-04 3.15e+03 input[0]
    2.65e-04 3.42e+03 output[0]
                 None output[1]
    2.25e-01 1.00e+04 output[2]
                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 8.76e+03 input[0]
    0.00e+00 9.74e+03 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The example output has been trimmed in the middle for brevity.

The second column shows the value of the absolute largest element, so if you have a closer look at the last few
frames, the inputs and outputs were in the range of ``1e4``.
So when this training was done under fp16 mixed precision the very last step overflowed (since under ``fp16`` the
largest finite number is ``65504``, roughly 64K). To avoid overflows under ``fp16`` the activations must remain way
below ``1e4``, because ``1e4 * 1e4 = 1e8``, so any matrix multiplication with large activations is going to lead to a
numerical overflow condition.

At the very start of the trace you can discover at which batch number the problem occurred (here ``Detected inf/nan
during batch_number=0`` means the problem occurred on the first batch).

Each reported frame starts by declaring the fully qualified entry for the corresponding module this frame is reporting
for. If we look just at this frame:

.. code-block::

                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output

Here, ``encoder.block.2.layer.1.layer_norm`` indicates that it was the layer norm of layer ``1`` in block ``2`` of the
encoder (indices as they appear in the module name). The specific ``forward`` call is ``T5LayerNorm.forward``.

Let's look at the last few frames of that report:

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The last frame reports on the ``Dropout.forward`` function, with the first entry for the only input and the second for
the only output.
You can see from the frame name that it was called through the ``dropout`` attribute of
``encoder.block.2.layer.1``, right after its ``DenseReluDense`` submodule returned. We can see that it happened in
layer ``1`` of block ``2``, during the very first batch. Finally, the absolute largest input element was ``6.27e+04``
and the absolute largest output element was ``inf``.

You can see here that ``T5DenseGatedGeluDense.forward`` resulted in output activations whose absolute max value was
around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout``, which rescales the
remaining elements after zeroing some of them. That pushes the absolute max value above 64K, and we get an overflow
(``inf``).

As you can see, it's the previous frames that we need to look into, where the numbers start getting very large for
fp16.

Let's match the report to the code from ``models/t5/modeling_t5.py``:

.. code-block:: python

    class T5DenseGatedGeluDense(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
            self.dropout = nn.Dropout(config.dropout_rate)
            self.gelu_act = ACT2FN["gelu_new"]

        def forward(self, hidden_states):
            hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
            hidden_linear = self.wi_1(hidden_states)
            hidden_states = hidden_gelu * hidden_linear
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.wo(hidden_states)
            return hidden_states

Now it's easy to see the ``dropout`` call, and all the previous calls as well.

Since the detection is happening in a forward hook, these reports are printed immediately after each ``forward``
returns.

Going back to the full report, to act on it and to fix the problem, we need to go a few frames up to where the numbers
started to grow and most likely switch to ``fp32`` mode there, so that the numbers don't overflow when multiplied or
summed up.
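
The overflow in the last frame can be checked with a little arithmetic. The following is a minimal sketch, not part of
the library: it assumes a dropout probability of ``0.1`` (not stated in the report, but consistent with the earlier
``Dropout`` frame, where ``8.76e+03`` became ``9.74e+03``) and uses the fact that ``nn.Dropout`` scales the kept
elements by ``1 / (1 - p)`` during training. The helper names ``fits_in_fp16`` and ``dropout_scale`` are made up for
this illustration.

.. code-block:: python

    import struct

    FP16_MAX = 65504.0  # largest finite value representable in IEEE half precision


    def fits_in_fp16(value):
        """Return True if ``value`` can be stored as a finite fp16 number."""
        try:
            # The "e" format packs an IEEE 754 half-precision float and raises
            # OverflowError when the value is too large to be represented.
            struct.pack("<e", value)
            return True
        except OverflowError:
            return False


    def dropout_scale(p):
        """Factor applied to the elements that dropout keeps during training."""
        return 1.0 / (1.0 - p)


    p = 0.1  # assumption: dropout probability, inferred from the report, not stated in it
    abs_max_before = 6.27e4  # abs max of input[0] in the last frame of the report
    abs_max_after = abs_max_before * dropout_scale(p)

    print(f"before dropout: {abs_max_before:.3e}, fits in fp16: {fits_in_fp16(abs_max_before)}")
    print(f"after dropout:  {abs_max_after:.3e}, fits in fp16: {fits_in_fp16(abs_max_after)}")

Under these assumptions the value before dropout is still representable, while the rescaled value is past the fp16
limit, which matches the ``inf`` in the report.
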
Of course, there might be other solutions. For example, we could turn off ``amp`` temporarily if it's enabled, after
moving the original ``forward`` into a helper wrapper, like so:

.. code-block:: python

    import torch

    def _forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states

    def forward(self, hidden_states):
        # Run the original computation with autocast disabled if it is currently on
        if torch.is_autocast_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(hidden_states)
        else:
            return self._forward(hidden_states)

Since the automatic detector only reports on inputs and outputs of full frames, once you know where to look, you may
want to analyse the intermediary stages of any specific ``forward`` function as well. In such a case you can use the
``detect_overflow`` helper function to inject the detector where you want it, for example:

.. code-block:: python

    from transformers.debug_utils import detect_overflow

    class T5LayerFF(nn.Module):
        [...]

        def forward(self, hidden_states):
            forwarded_states = self.layer_norm(hidden_states)
            detect_overflow(forwarded_states, "after layer_norm")
            forwarded_states = self.DenseReluDense(forwarded_states)
            detect_overflow(forwarded_states, "after DenseReluDense")
            return hidden_states + self.dropout(forwarded_states)

You can see that we added 2 of these, and now we track whether ``inf`` or ``nan`` was detected in ``forwarded_states``
somewhere in between.

Actually, the detector already reports these because each of the calls in the example above is an ``nn.Module``, but
if you had some local direct calculations, this is how you'd do that.

Additionally, if you're instantiating the debugger in your own code, you can adjust the number of frames printed from
its default, e.g.:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

Specific batch absolute min and max value tracing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The same debugging class can be used for per-batch tracing with the underflow/overflow detection feature turned off.

Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a given
batch, and only do that for batches 1 and 3. Then you instantiate this class as:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

And now full batches 1 and 3 will be traced using the same format as the underflow/overflow detector does.

Batches are 0-indexed.

This is helpful if you know that the program starts misbehaving after a certain batch number, so you can fast-forward
right to that area. Here is a sample truncated output for such a configuration:

.. code-block::

                      *** Starting batch number=1 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.47e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]
                      decoder.dropout Dropout
    1.60e-07 2.27e+01 input[0]
    0.00e+00 2.52e+01 output
                      decoder T5Stack
         not a tensor output
                      lm_head Linear
    1.01e-06 7.92e+02 weight
    0.00e+00 1.11e+00 input[0]
    6.06e-02 8.39e+01 output
                       T5ForConditionalGeneration
         not a tensor output

                      *** Starting batch number=3 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.78e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]

Here you will get a huge number of frames dumped, as many as there were forward calls in your model, so it may or may
not be what you want, but sometimes it can be easier to use for debugging purposes than a normal debugger. For example,
suppose a problem starts happening at batch number 150.
So you can dump traces for batches 149 and 150 and compare where the numbers started to diverge.

You can also specify the batch number after which to stop the training, with:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
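
Comparing two batch traces by eye can be tedious, so the comparison mentioned above could be scripted. The following is
only a rough sketch under stated assumptions: it assumes the trace text has the layout shown in the sample output
(indented frame header lines, followed by ``abs min  abs max  metadata`` rows), and the helpers ``parse_trace`` and
``first_divergence``, the ``ratio`` threshold and the two toy traces are all made up for this illustration. None of
them are part of ``transformers``.

.. code-block:: python

    import re

    # Row labels seen in the reports: weight, bias, input[N], output, output[N]
    META = re.compile(r"^(weight|bias|input\[\d+\]|output(\[\d+\])?)$")


    def parse_trace(text):
        """Return a list of (frame, metadata, abs_max) rows from one batch trace."""
        rows, frame = [], None
        for line in text.splitlines():
            parts = line.split()
            if not parts or parts[0].startswith("***"):
                continue  # blank line or "*** Starting batch number=N ***"
            if META.match(parts[-1]):
                try:
                    rows.append((frame, parts[-1], float(parts[-2])))
                except (ValueError, IndexError):
                    pass  # non-numeric rows such as "None output[1]"
            elif line[0].isspace():
                frame = " ".join(parts)  # indented frame header, e.g. "lm_head Linear"
        return rows


    def first_divergence(rows_a, rows_b, ratio=10.0):
        """Return the first row whose abs max differs by more than ``ratio`` times."""
        for (frame, meta, a), (_, _, b) in zip(rows_a, rows_b):
            lo, hi = sorted((a, b))
            if hi > ratio * lo:
                return frame, meta, a, b
        return None


    # Toy traces with made-up numbers, only to exercise the helpers.
    INDENT = " " * 18
    trace_a = "\n".join([
        "abs min  abs max  metadata",
        INDENT + "block.wi Linear",
        "1.79e-06 4.65e+00 input[0]",
        "2.68e-06 3.70e+01 output",
        INDENT + "block.wo Linear",
        "1.00e-06 9.74e+01 input[0]",
        "3.18e-04 6.27e+02 output",
    ])
    trace_b = trace_a.replace("9.74e+01", "9.74e+03").replace("6.27e+02", "6.27e+04")

    print(first_divergence(parse_trace(trace_a), parse_trace(trace_b)))

The rows of the two traces are paired by position, so this only makes sense when both batches go through the same
sequence of ``forward`` calls. Note also that a row where one value is exactly zero and the other is not will be
flagged, which may or may not be what you want.
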