Create trainer_pt_utils.py
trainer_pt_utils.py  ADDED  (+1106 -0)
@@ -0,0 +1,1106 @@
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Torch utilities for the Trainer class.
"""

import datetime
import json
import math
import os
import sys
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from logging import StreamHandler
from typing import Any, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler

from .tokenization_utils_base import BatchEncoding
from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging


if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
try:
    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
    SAVE_STATE_WARNING = ""

logger = logging.get_logger(__name__)


def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
    if isinstance(tensor_or_array, torch.Tensor):
        if hasattr(torch, "atleast_1d"):
            tensor_or_array = torch.atleast_1d(tensor_or_array)
        elif tensor_or_array.ndim < 1:
            tensor_or_array = tensor_or_array[None]
    else:
        tensor_or_array = np.atleast_1d(tensor_or_array)
    return tensor_or_array


def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
    """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
    tensor1 = atleast_1d(tensor1)
    tensor2 = atleast_1d(tensor2)

    if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
        return torch.cat((tensor1, tensor2), dim=0)

    # Let's figure out the new shape
    new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]

    # Now let's fill the result tensor
    result = tensor1.new_full(new_shape, padding_index)
    result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
    result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
    return result


def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
    """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
    array1 = atleast_1d(array1)
    array2 = atleast_1d(array2)

    if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
        return np.concatenate((array1, array2), axis=0)

    # Let's figure out the new shape
    new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

    # Now let's fill the result tensor
    result = np.full_like(array1, padding_index, shape=new_shape)
    result[: array1.shape[0], : array1.shape[1]] = array1
    result[array1.shape[0] :, : array2.shape[1]] = array2
    return result


def nested_concat(tensors, new_tensors, padding_index=-100):
    """
    Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
    nested list/tuples/dict of tensors.
    """
    assert type(tensors) == type(
        new_tensors
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
    elif isinstance(tensors, torch.Tensor):
        return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    elif isinstance(tensors, Mapping):
        return type(tensors)(
            {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
        )
    elif isinstance(tensors, np.ndarray):
        return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    else:
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

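
# --- Editor's usage sketch (not part of the upstream file): a minimal illustration of how
# `nested_concat` accumulates eval-loop outputs whose sequence lengths differ between steps.
# The helper name `_demo_nested_concat` is hypothetical; everything it calls is defined above.
def _demo_nested_concat():
    batch1 = {"logits": torch.ones(2, 3), "labels": torch.zeros(2, 3)}
    batch2 = {"logits": torch.ones(2, 5), "labels": torch.zeros(2, 5)}
    merged = nested_concat(batch1, batch2, padding_index=-100)
    # Rows are stacked on the first axis; the shorter tensors are right-padded with -100
    # up to the longest sequence length seen so far.
    assert merged["logits"].shape == (4, 5)
    assert (merged["logits"][0, 3:] == -100).all()

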
def find_batch_size(tensors):
    """
    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
    """
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            result = find_batch_size(t)
            if result is not None:
                return result
    elif isinstance(tensors, Mapping):
        for key, value in tensors.items():
            result = find_batch_size(value)
            if result is not None:
                return result
    elif isinstance(tensors, torch.Tensor):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None
    elif isinstance(tensors, np.ndarray):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None


def nested_numpify(tensors):
    "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_numpify(t) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})

    t = tensors.cpu()
    if t.dtype == torch.bfloat16:
        # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
        # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
        # Until Numpy adds bfloat16, we must convert to float32.
        t = t.to(torch.float32)
    return t.numpy()


def nested_detach(tensors):
    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
    return tensors.detach()


def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        if isinstance(tensors, Mapping):
            return type(tensors)(
                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
            )

        tensors = atleast_1d(tensors)
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")


def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    try:
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
        if isinstance(tensor, Mapping):
            return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
        tensor = atleast_1d(tensor).contiguous()
        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")


def distributed_broadcast_scalars(
    scalars: List[Union[int, float]],
    num_total_examples: Optional[int] = None,
    device: Optional[torch.device] = torch.device("cuda"),
) -> torch.Tensor:
    try:
        tensorized_scalar = torch.tensor(scalars).to(device)
        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(output_tensors, tensorized_scalar)
        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is not None:
            concat = concat[:num_total_examples]
        return concat
    except AssertionError:
        raise AssertionError("Not currently using distributed training")


def reissue_pt_warnings(caught_warnings):
    # Reissue warnings that are not the SAVE_STATE_WARNING
    if len(caught_warnings) > 1:
        for w in caught_warnings:
            if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
                warnings.warn(w.message, w.category)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.

    Args:
        local_rank (`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        dist.barrier()
    yield
    if local_rank == 0:
        dist.barrier()


class DistributedSamplerWithLoop(DistributedSampler):
    """
    Like a `torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled
    samples to make each process have a round multiple of batch_size samples.

    Args:
        dataset (`torch.utils.data.Dataset`):
            Dataset used for sampling.
        batch_size (`int`):
            The batch size used with this sampler.
        kwargs:
            All other keyword arguments passed to `DistributedSampler`.
    """

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(super().__iter__())
        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
        # of the world size, so we skip those.
        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
        indices += indices[start_remainder : start_remainder + remainder]
        return iter(indices)


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training), which means that the model params won't
    have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add
    extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather`
    or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
        warnings.warn(
            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        num_samples = len(self.dataset)
        # Add extra samples to make num_samples a multiple of batch_size if passed
        if batch_size is not None:
            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
        else:
            self.num_samples = int(math.ceil(num_samples / num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples, padding_index=-100):
    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))


def expand_like(arrays, new_seq_length, padding_index=-100):
    """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
    result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
    result[:, : arrays.shape[1]] = arrays
    return result


def nested_truncate(tensors, limit):
    "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_truncate(t, limit) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})

    return tensors[:limit]


class DistributedTensorGatherer:
    """
    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.

    If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
    step, our sampler will generate the following indices:

    `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`

    to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and
    2 will be responsible for making predictions for the following samples:

    - P0: `[0, 1, 2, 3, 4, 5]`
    - P1: `[6, 7, 8, 9, 10, 11]`
    - P2: `[12, 13, 14, 15, 0, 1]`

    The first batch treated on each process will be

    - P0: `[0, 1]`
    - P1: `[6, 7]`
    - P2: `[12, 13]`

    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
    the following indices:

    `[0, 1, 6, 7, 12, 13]`

    If we directly concatenate our results without taking any precautions, the user will then get the predictions for
    the indices in this order at the end of the prediction loop:

    `[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`

    For some reason, that's not going to roll their boat. This class is there to solve that problem.

    Args:
        world_size (`int`):
            The number of processes used in the distributed training.
        num_samples (`int`):
            The number of samples in our dataset.
        make_multiple_of (`int`, *optional*):
            If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
            (by adding samples).
        padding_index (`int`, *optional*, defaults to -100):
            The padding index to use if the arrays don't all have the same sequence length.
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
        warnings.warn(
            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.world_size = world_size
        self.num_samples = num_samples
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
        self.process_length = self.total_samples // world_size
        self._storage = None
        self._offsets = None
        self.padding_index = padding_index

    def add_arrays(self, arrays):
        """
        Add `arrays` to the internal storage. Will initialize the storage to the full size at the first arrays passed
        so that if we're bound to get an OOM, it happens at the beginning.
        """
        if arrays is None:
            return
        if self._storage is None:
            self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
            self._offsets = list(range(0, self.total_samples, self.process_length))

        slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
        for i in range(self.world_size):
            self._offsets[i] += slice_len

    def _nested_set_tensors(self, storage, arrays):
        if isinstance(arrays, (list, tuple)):
            result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
            return result[0][0], type(arrays)(r[1] for r in result)
        assert (
            arrays.shape[0] % self.world_size == 0
        ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."

        slice_len = arrays.shape[0] // self.world_size
        for i in range(self.world_size):
            if len(arrays.shape) == 1:
                storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
            else:
                # Expand the array on the fly if needed.
                if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
                    storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
                storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
                    i * slice_len : (i + 1) * slice_len
                ]
        return slice_len, storage

    def finalize(self):
        """
        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
        to get each process a dataset of the same length).
        """
        if self._storage is None:
            return
        if self._offsets[0] != self.process_length:
            logger.warning("Not all data has been set. Are you sure you passed all values?")
        return nested_truncate(self._storage, self.num_samples)

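
# --- Editor's usage sketch (not part of the upstream file): `DistributedTensorGatherer` is pure
# NumPy, so its gathering logic can be exercised on a single process. Below, `world_size=2` with 5
# real samples, i.e. a SequentialDistributedSampler-style split of [0, 1, 2] / [3, 4, 0]; each
# `add_arrays` call receives the already-gathered predictions of one eval step. The helper name
# `_demo_distributed_tensor_gatherer` is hypothetical. (The class emits a FutureWarning, since it
# is deprecated above.)
def _demo_distributed_tensor_gatherer():
    gatherer = DistributedTensorGatherer(world_size=2, num_samples=5)
    # Step 1: P0 predicted samples [0, 1], P1 predicted samples [3, 4].
    gatherer.add_arrays(np.array([0.0, 1.0, 3.0, 4.0]))
    # Step 2: P0 predicted sample [2], P1 predicted the looped-around duplicate of sample 0.
    gatherer.add_arrays(np.array([2.0, 0.0]))
    # finalize() interleaves the per-process chunks back into dataset order and drops the dummy.
    preds = gatherer.finalize()
    assert np.array_equal(preds, np.array([0.0, 1.0, 2.0, 3.0, 4.0]))

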
@dataclass
class LabelSmoother:
    """
    Adds label-smoothing on a pre-computed output from a Transformers model.

    Args:
        epsilon (`float`, *optional*, defaults to 0.1):
            The label smoothing factor.
        ignore_index (`int`, *optional*, defaults to -100):
            The index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self, model_output, labels, shift_labels=False):
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        if shift_labels:
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()

        log_probs = -nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
        # will ignore them in any case.
        labels = torch.clamp(labels, min=0)
        nll_loss = log_probs.gather(dim=-1, index=labels)
        # works for fp16 input tensor too, by internally upcasting it to fp32
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum() / num_active_elements
        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss

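
# --- Editor's usage sketch (not part of the upstream file): applying `LabelSmoother` to a fake
# model output. Positions labeled `-100` are masked out of both loss terms. The helper name
# `_demo_label_smoother` is hypothetical.
def _demo_label_smoother():
    smoother = LabelSmoother(epsilon=0.1)
    logits = torch.randn(2, 4, 8)  # (batch, seq_len, vocab_size)
    labels = torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])
    loss = smoother({"logits": logits}, labels)
    # A scalar mixing (1 - epsilon) * NLL with epsilon * uniform-smoothing loss.
    assert loss.dim() == 0 and loss.item() > 0

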
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
    """
    Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
    lengths. To do this, the indices are:

    - randomly permuted
    - grouped in mega-batches of size `mega_batch_mult * batch_size`
    - sorted by length in each mega-batch

    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
    maximum length placed first, so that an OOM happens sooner rather than later.
    """
    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
    if mega_batch_mult is None:
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
        # Just in case, for tiny datasets
        if mega_batch_mult == 0:
            mega_batch_mult = 1

    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]

    # The rest is to get the biggest batch first.
    # Since each megabatch is sorted by descending length, the longest element is the first
    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
    # Switch to put the longest element in first position
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

    return [i for megabatch in megabatches for i in megabatch]

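
# --- Editor's usage sketch (not part of the upstream file): the returned list is always a
# permutation of all indices, and the very first batch holds the globally longest element
# (to trigger any OOM early). The helper name `_demo_length_grouped_indices` is hypothetical.
def _demo_length_grouped_indices():
    lengths = [3, 17, 5, 9, 2, 11, 8, 4]
    g = torch.Generator()
    g.manual_seed(0)
    indices = get_length_grouped_indices(lengths, batch_size=2, generator=g)
    assert sorted(indices) == list(range(len(lengths)))
    assert lengths[indices[0]] == max(lengths)

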
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        dataset: Optional[Dataset] = None,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
        generator=None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")

        self.batch_size = batch_size
        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if (
                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
                or model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
            )
            lengths = lengths.tolist()

        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
        return iter(indices)


class DistributedLengthGroupedSampler(DistributedSampler):
    r"""
    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        batch_size: int,
        dataset: Optional[Dataset] = None,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            if (
                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
                or model_input_name not in dataset[0]
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
                " List[int]..."
            )
            lengths = lengths.tolist()

        self.lengths = lengths

        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.lengths) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.seed = seed

    def __iter__(self) -> Iterator:
        # Deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)


class ShardSampler(Sampler):
    """
    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
    size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into
    `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1.

    The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index

        self.total_batch_size = total_batch_size = batch_size * num_processes

        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
        self.total_num_samples = num_batches * total_batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
        # and it needs to be done several times.
        while len(indices) < self.total_num_samples:
            indices += indices[: (self.total_num_samples - len(indices))]

        result = []
        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
            result += indices[batch_start : batch_start + self.batch_size]

        return iter(result)

    def __len__(self):
        # Each shard only sees a fraction of total_num_samples.
        return self.total_num_samples // self.num_processes

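
# --- Editor's usage sketch (not part of the upstream file): reproducing the docstring example
# above on a plain list (ShardSampler only needs `len()` and integer indices). The helper name
# `_demo_shard_sampler` is hypothetical.
def _demo_shard_sampler():
    dataset = list(range(16))
    shard0 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0)
    shard1 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1)
    assert list(shard0) == [0, 1, 2, 3, 8, 9, 10, 11]
    assert list(shard1) == [4, 5, 6, 7, 12, 13, 14, 15]

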
class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x
    num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the
    first batch that would be too small or loop with indices from the beginning.

    On two processes with an iterable dataset yielding `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of
    2:

    - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]`
    - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]`

    <Tip warning={true}>

    If your IterableDataset implements some randomization that needs to be applied the same way on all processes
    (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to
    generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this
    object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the
    iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with
    this.

    </Tip>

    Args:
        dataset (`torch.utils.data.IterableDataset`):
            The iterable dataset to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard.
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        seed (`int`, *optional*, defaults to 0):
            A random seed that will be used for the random number generation in
            [`~trainer_pt_utils.IterableDatasetShard.set_epoch`].
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        seed: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
        self.num_examples = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __iter__(self):
        self.num_examples = 0
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.seed + self.epoch)
        real_batch_size = self.batch_size * self.num_processes
        process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            self.num_examples += 1
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]

    def __len__(self):
        # Will raise an error if the underlying dataset is not sized.
        if self.drop_last:
            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
        else:
            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

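
# --- Editor's usage sketch (not part of the upstream file): reproducing the docstring example
# above with a minimal IterableDataset. The helper names `_demo_iterable_dataset_shard` and
# `_Stream` are hypothetical.
def _demo_iterable_dataset_shard():
    class _Stream(IterableDataset):
        def __iter__(self):
            return iter(range(12))

    shard0 = IterableDatasetShard(_Stream(), batch_size=2, num_processes=2, process_index=0)
    shard1 = IterableDatasetShard(_Stream(), batch_size=2, num_processes=2, process_index=1)
    # Each shard takes its slice out of every global batch of batch_size * num_processes elements.
    assert list(shard0) == [0, 1, 4, 5, 8, 9]
    assert list(shard1) == [2, 3, 6, 7, 10, 11]

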
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here


def _get_learning_rate(self):
    if self.deepspeed:
        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
        # not run for the first few dozen steps while loss scale is too large, and thus during
        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
        try:
            last_lr = self.lr_scheduler.get_last_lr()[0]
        except AssertionError as e:
            if "need to call step" in str(e):
                logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
                last_lr = 0
            else:
                raise
    else:
        last_lr = self.lr_scheduler.get_last_lr()[0]
        if torch.is_tensor(last_lr):
            last_lr = last_lr.item()
    return last_lr


def _secs2timedelta(secs):
    """
    convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
    """

    msec = int(abs(secs - int(secs)) * 100)
    return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"


def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
    """
    Reformat Trainer metrics values to a human-readable format

    Args:
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict

    Returns:
        metrics (`Dict[str, float]`): The reformatted metrics
    """

    metrics_copy = metrics.copy()
    for k, v in metrics_copy.items():
        if "_mem_" in k:
            metrics_copy[k] = f"{ v >> 20 }MB"
        elif "_runtime" in k:
            metrics_copy[k] = _secs2timedelta(v)
        elif k == "total_flos":
            metrics_copy[k] = f"{ int(v) >> 30 }GF"
        elif type(metrics_copy[k]) == float:
            metrics_copy[k] = round(v, 4)

    return metrics_copy

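
# --- Editor's usage sketch (not part of the upstream file): `metrics_format` never touches
# `self`, so it can be exercised without a Trainer instance by passing `None`. Memory metrics
# become "<n>MB", runtimes become "h:mm:ss.ms", `total_flos` becomes "<n>GF", and plain floats
# are rounded. The helper name `_demo_metrics_format` is hypothetical.
def _demo_metrics_format():
    metrics = {
        "train_runtime": 3661.5,
        "train_mem_gpu_alloc_delta": 693 * 2**20,
        "total_flos": 5 * 2**30,
        "eval_loss": 0.123456,
    }
    formatted = metrics_format(None, metrics)  # `self` is unused by the function body
    assert formatted["train_mem_gpu_alloc_delta"] == "693MB"
    assert formatted["train_runtime"] == "1:01:01.50"
    assert formatted["total_flos"] == "5GF"
    assert formatted["eval_loss"] == 0.1235

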
def log_metrics(self, split, metrics):
    """
    Log metrics in a specially formatted way

    Under distributed environment this is done only for a process with rank 0.

    Args:
        split (`str`):
            Mode/split name: one of `train`, `eval`, `test`
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict

    Notes on memory reports:

    In order to get memory usage report you need to install `psutil`. You can do that with `pip install psutil`.

    Now when this method is run, you will see a report that will include:

    ```
    init_mem_cpu_alloc_delta = 1301MB
    init_mem_cpu_peaked_delta = 154MB
    init_mem_gpu_alloc_delta = 230MB
    init_mem_gpu_peaked_delta = 0MB
    train_mem_cpu_alloc_delta = 1345MB
    train_mem_cpu_peaked_delta = 0MB
    train_mem_gpu_alloc_delta = 693MB
    train_mem_gpu_peaked_delta = 7MB
    ```

    **Understanding the reports:**

    - the first segment, e.g., `train__`, tells you which stage the metrics are for. Reports starting with `init_`
      will be added to the first stage that gets run. So that if only evaluation is run, the memory usage for the
      `__init__` will be reported along with the `eval_` metrics.
    - the third segment is either `cpu` or `gpu` and tells you whether it's the general RAM or the gpu0 memory
      metric.
    - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the
      stage - it can be negative if a function released more memory than it allocated.
    - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated
      memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` +
      `peaked_delta` and you know how much memory was needed to complete that stage.

    The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
    main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may
    use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more
    memory than the rest since it stores the gradient and optimizer states for all participating GPUS. Perhaps in the
    future these reports will evolve to measure those too.

    The CPU RAM metric measures RSS (Resident Set Size), which includes both the memory which is unique to the process
    and the memory shared with other processes. It is important to note that it does not include swapped out memory, so
    the reports could be imprecise.

    The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
    that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than
    reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations
    outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it
    was dropped in favor of the memory sampling approach, which reads the current process memory usage.

    The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and
    `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as
    `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very
    first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.

    Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`,
    `evaluate` and `predict` calls.

    Because `evaluation` calls may happen during `train`, we can't handle nested invocations because
    `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker
    will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved
    it will be possible to change this class to be re-entrant. Until then we will only track the outer level of
    `train`, `evaluate` and `predict` methods. Which means that if `eval` is called during `train`, it's the latter
    that will account for its memory usage and that of the former.

    This also means that if any other tool that is used along the [`Trainer`] calls
    `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt
    the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves.

    For best performance you may want to consider turning the memory profiling off for production runs.
    """
    if not self.is_world_process_zero():
        return

    print(f"***** {split} metrics *****")
    metrics_formatted = self.metrics_format(metrics)
    k_width = max(len(str(x)) for x in metrics_formatted.keys())
    v_width = max(len(str(x)) for x in metrics_formatted.values())
    for key in sorted(metrics_formatted.keys()):
        print(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")


def save_metrics(self, split, metrics, combined=True):
    """
    Save metrics into a json file for that split, e.g. `train_results.json`.

    Under distributed environment this is done only for a process with rank 0.

    Args:
        split (`str`):
            Mode/split name: one of `train`, `eval`, `test`, `all`
        metrics (`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
        combined (`bool`, *optional*, defaults to `True`):
            Creates combined metrics by updating `all_results.json` with metrics of this call

    To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
    unformatted numbers are saved in the current method.

    """
    if not self.is_world_process_zero():
        return

    path = os.path.join(self.args.output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

    if combined:
        path = os.path.join(self.args.output_dir, "all_results.json")
        if os.path.exists(path):
            with open(path, "r") as f:
                all_metrics = json.load(f)
        else:
            all_metrics = {}

        all_metrics.update(metrics)
        with open(path, "w") as f:
            json.dump(all_metrics, f, indent=4, sort_keys=True)


def save_state(self):
    """
    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model

    Under distributed environment this is done only for a process with rank 0.
    """
    if not self.is_world_process_zero():
        return

    path = os.path.join(self.args.output_dir, "trainer_state.json")
    self.state.save_to_json(path)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result

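
# --- Editor's usage sketch (not part of the upstream file): the typical use of
# `get_parameter_names` is to build optimizer parameter groups so that LayerNorm weights and all
# biases receive no weight decay (this mirrors how `Trainer` sets up its optimizer). The helper
# name `_demo_weight_decay_param_groups` is hypothetical.
def _demo_weight_decay_param_groups():
    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    assert decay_parameters == ["0.weight"]  # LayerNorm params and biases are excluded
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if n in decay_parameters], "weight_decay": 0.01},
        {"params": [p for n, p in model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-3)

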
def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.
    """
    modules_children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__
    elif len(modules_children) == 0:
        return
    else:
        for child_module in modules_children:
            module_class = get_module_class_from_name(child_module, name)
            if module_class is not None:
                return module_class


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    @smp.step()
    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        loss /= gradient_accumulation_steps
        model.backward(loss)
        return loss

    @smp.step()
    def smp_forward_only(model, inputs):
        return model(**inputs)

    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        all_tensors = [atleast_1d(t) for t in all_tensors]
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def smp_nested_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_nested_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()