Cletrason committed on
Commit c57fa4c
1 Parent(s): ba0418b

Create trainer_pt_utils.py

Files changed (1)
  1. trainer_pt_utils.py +1106 -0
trainer_pt_utils.py ADDED
@@ -0,0 +1,1106 @@
1
+ # coding=utf-8
2
+ # Copyright 2020-present the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Torch utilities for the Trainer class.
17
+ """
18
+
19
+ import datetime
20
+ import json
21
+ import math
22
+ import os
23
+ import sys
24
+ import warnings
25
+ from collections.abc import Mapping
26
+ from contextlib import contextmanager
27
+ from dataclasses import dataclass
28
+ from logging import StreamHandler
29
+ from typing import Any, Dict, Iterator, List, Optional, Union
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.distributed as dist
34
+ from torch import nn
35
+ from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
36
+ from torch.utils.data.distributed import DistributedSampler
37
+
38
+ from .tokenization_utils_base import BatchEncoding
39
+ from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
40
+
41
+
42
+ if is_training_run_on_sagemaker():
43
+ logging.add_handler(StreamHandler(sys.stdout))
44
+
45
+ if is_torch_tpu_available(check_device=False):
46
+ import torch_xla.core.xla_model as xm
47
+
48
+ # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
49
+ try:
50
+ from torch.optim.lr_scheduler import SAVE_STATE_WARNING
51
+ except ImportError:
52
+ SAVE_STATE_WARNING = ""
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
58
+ if isinstance(tensor_or_array, torch.Tensor):
59
+ if hasattr(torch, "atleast_1d"):
60
+ tensor_or_array = torch.atleast_1d(tensor_or_array)
61
+ elif tensor_or_array.ndim < 1:
62
+ tensor_or_array = tensor_or_array[None]
63
+ else:
64
+ tensor_or_array = np.atleast_1d(tensor_or_array)
65
+ return tensor_or_array
66
+
67
+
68
+ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
69
+ """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
70
+ tensor1 = atleast_1d(tensor1)
71
+ tensor2 = atleast_1d(tensor2)
72
+
73
+ if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
74
+ return torch.cat((tensor1, tensor2), dim=0)
75
+
76
+ # Let's figure out the new shape
77
+ new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]
78
+
79
+ # Now let's fill the result tensor
80
+ result = tensor1.new_full(new_shape, padding_index)
81
+ result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
82
+ result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
83
+ return result
84
+
85
+
86
+ def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
87
+ """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
88
+ array1 = atleast_1d(array1)
89
+ array2 = atleast_1d(array2)
90
+
91
+ if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
92
+ return np.concatenate((array1, array2), axis=0)
93
+
94
+ # Let's figure out the new shape
95
+ new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]
96
+
97
+ # Now let's fill the result tensor
98
+ result = np.full_like(array1, padding_index, shape=new_shape)
99
+ result[: array1.shape[0], : array1.shape[1]] = array1
100
+ result[array1.shape[0] :, : array2.shape[1]] = array2
101
+ return result
102
+
103
+
104
+ def nested_concat(tensors, new_tensors, padding_index=-100):
105
+ """
106
+ Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
107
+ nested list/tuples/dict of tensors.
108
+ """
109
+ assert type(tensors) == type(
110
+ new_tensors
111
+ ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
112
+ if isinstance(tensors, (list, tuple)):
113
+ return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
114
+ elif isinstance(tensors, torch.Tensor):
115
+ return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
116
+ elif isinstance(tensors, Mapping):
117
+ return type(tensors)(
118
+ {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
119
+ )
120
+ elif isinstance(tensors, np.ndarray):
121
+ return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
122
+ else:
123
+ raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
124
+
125
+
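Note: a minimal usage sketch (illustrative only, not part of the committed file) showing how `nested_concat` pads mismatched sequence lengths with the default padding index of -100; it assumes the helpers defined above are importable from this module:

import torch

batch_a = {"logits": torch.ones(2, 3), "labels": torch.ones(2, 3, dtype=torch.long)}
batch_b = {"logits": torch.ones(2, 5), "labels": torch.ones(2, 5, dtype=torch.long)}
merged = nested_concat(batch_a, batch_b)
print(merged["logits"].shape)  # torch.Size([4, 5]); the shorter tensors are right-padded with -100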
126
+ def find_batch_size(tensors):
127
+ """
128
+ Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
129
+ """
130
+ if isinstance(tensors, (list, tuple)):
131
+ for t in tensors:
132
+ result = find_batch_size(t)
133
+ if result is not None:
134
+ return result
135
+ elif isinstance(tensors, Mapping):
136
+ for key, value in tensors.items():
137
+ result = find_batch_size(value)
138
+ if result is not None:
139
+ return result
140
+ elif isinstance(tensors, torch.Tensor):
141
+ return tensors.shape[0] if len(tensors.shape) >= 1 else None
142
+ elif isinstance(tensors, np.ndarray):
143
+ return tensors.shape[0] if len(tensors.shape) >= 1 else None
144
+
145
+
146
+ def nested_numpify(tensors):
147
+ "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
148
+ if isinstance(tensors, (list, tuple)):
149
+ return type(tensors)(nested_numpify(t) for t in tensors)
150
+ if isinstance(tensors, Mapping):
151
+ return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})
152
+
153
+ t = tensors.cpu()
154
+ if t.dtype == torch.bfloat16:
155
+ # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
156
+ # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
157
+ # Until NumPy adds bfloat16 support, we must convert to float32.
158
+ t = t.to(torch.float32)
159
+ return t.numpy()
160
+
161
+
162
+ def nested_detach(tensors):
163
+ "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
164
+ if isinstance(tensors, (list, tuple)):
165
+ return type(tensors)(nested_detach(t) for t in tensors)
166
+ elif isinstance(tensors, Mapping):
167
+ return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
168
+ return tensors.detach()
169
+
170
+
171
+ def nested_xla_mesh_reduce(tensors, name):
172
+ if is_torch_tpu_available():
173
+ import torch_xla.core.xla_model as xm
174
+
175
+ if isinstance(tensors, (list, tuple)):
176
+ return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
177
+ if isinstance(tensors, Mapping):
178
+ return type(tensors)(
179
+ {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
180
+ )
181
+
182
+ tensors = atleast_1d(tensors)
183
+ return xm.mesh_reduce(name, tensors, torch.cat)
184
+ else:
185
+ raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
186
+
187
+
188
+ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
189
+ try:
190
+ if isinstance(tensor, (tuple, list)):
191
+ return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
192
+ if isinstance(tensor, Mapping):
193
+ return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
194
+ tensor = atleast_1d(tensor).contiguous()
195
+ output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
196
+ dist.all_gather(output_tensors, tensor)
197
+ concat = torch.cat(output_tensors, dim=0)
198
+
199
+ # truncate the dummy elements added by SequentialDistributedSampler
200
+ if num_total_examples is not None:
201
+ concat = concat[:num_total_examples]
202
+ return concat
203
+ except AssertionError:
204
+ raise AssertionError("Not currently using distributed training")
205
+
206
+
207
+ def distributed_broadcast_scalars(
208
+ scalars: List[Union[int, float]],
209
+ num_total_examples: Optional[int] = None,
210
+ device: Optional[torch.device] = torch.device("cuda"),
211
+ ) -> torch.Tensor:
212
+ try:
213
+ tensorized_scalar = torch.tensor(scalars).to(device)
214
+ output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
215
+ dist.all_gather(output_tensors, tensorized_scalar)
216
+ concat = torch.cat(output_tensors, dim=0)
217
+
218
+ # truncate the dummy elements added by SequentialDistributedSampler
219
+ if num_total_examples is not None:
220
+ concat = concat[:num_total_examples]
221
+ return concat
222
+ except AssertionError:
223
+ raise AssertionError("Not currently using distributed training")
224
+
225
+
226
+ def reissue_pt_warnings(caught_warnings):
227
+ # Reissue warnings that are not the SAVE_STATE_WARNING
228
+ if len(caught_warnings) > 1:
229
+ for w in caught_warnings:
230
+ if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
231
+ warnings.warn(w.message, w.category)
232
+
233
+
234
+ @contextmanager
235
+ def torch_distributed_zero_first(local_rank: int):
236
+ """
237
+ Context manager to make all processes in distributed training wait for each local_master to do something.
238
+
239
+ Args:
240
+ local_rank (`int`): The rank of the local process.
241
+ """
242
+ if local_rank not in [-1, 0]:
243
+ dist.barrier()
244
+ yield
245
+ if local_rank == 0:
246
+ dist.barrier()
247
+
248
+
249
+ class DistributedSamplerWithLoop(DistributedSampler):
250
+ """
251
+ Like a `torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled
252
+ samples to make each process have a round multiple of batch_size samples.
253
+
254
+ Args:
255
+ dataset (`torch.utils.data.Dataset`):
256
+ Dataset used for sampling.
257
+ batch_size (`int`):
258
+ The batch size used with this sampler.
259
+ kwargs:
260
+ All other keyword arguments passed to `DistributedSampler`.
261
+ """
262
+
263
+ def __init__(self, dataset, batch_size, **kwargs):
264
+ super().__init__(dataset, **kwargs)
265
+ self.batch_size = batch_size
266
+
267
+ def __iter__(self):
268
+ indices = list(super().__iter__())
269
+ remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
270
+ # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
271
+ # of the world size, so we skip those.
272
+ start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
273
+ indices += indices[start_remainder : start_remainder + remainder]
274
+ return iter(indices)
275
+
276
+
277
+ class SequentialDistributedSampler(Sampler):
278
+ """
279
+ Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
280
+
281
+ Even though we only use this sampler for eval and predict (no training), which means that the model params won't
282
+ have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add
283
+ extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather`
284
+ or `reduce` resulting tensors at the end of the loop.
285
+ """
286
+
287
+ def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
288
+ warnings.warn(
289
+ "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
290
+ FutureWarning,
291
+ )
292
+ if num_replicas is None:
293
+ if not dist.is_available():
294
+ raise RuntimeError("Requires distributed package to be available")
295
+ num_replicas = dist.get_world_size()
296
+ if rank is None:
297
+ if not dist.is_available():
298
+ raise RuntimeError("Requires distributed package to be available")
299
+ rank = dist.get_rank()
300
+ self.dataset = dataset
301
+ self.num_replicas = num_replicas
302
+ self.rank = rank
303
+ num_samples = len(self.dataset)
304
+ # Add extra samples to make num_samples a multiple of batch_size if passed
305
+ if batch_size is not None:
306
+ self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
307
+ else:
308
+ self.num_samples = int(math.ceil(num_samples / num_replicas))
309
+ self.total_size = self.num_samples * self.num_replicas
310
+ self.batch_size = batch_size
311
+
312
+ def __iter__(self):
313
+ indices = list(range(len(self.dataset)))
314
+
315
+ # add extra samples to make it evenly divisible
316
+ indices += indices[: (self.total_size - len(indices))]
317
+ assert (
318
+ len(indices) == self.total_size
319
+ ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
320
+
321
+ # subsample
322
+ indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
323
+ assert (
324
+ len(indices) == self.num_samples
325
+ ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
326
+
327
+ return iter(indices)
328
+
329
+ def __len__(self):
330
+ return self.num_samples
331
+
332
+
333
+ def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
334
+ if xm.xrt_world_size() <= 1:
335
+ return RandomSampler(dataset)
336
+ return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
337
+
338
+
339
+ def nested_new_like(arrays, num_samples, padding_index=-100):
340
+ """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
341
+ if isinstance(arrays, (list, tuple)):
342
+ return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
343
+ return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
344
+
345
+
346
+ def expand_like(arrays, new_seq_length, padding_index=-100):
347
+ """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
348
+ result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
349
+ result[:, : arrays.shape[1]] = arrays
350
+ return result
351
+
352
+
353
+ def nested_truncate(tensors, limit):
354
+ "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
355
+ if isinstance(tensors, (list, tuple)):
356
+ return type(tensors)(nested_truncate(t, limit) for t in tensors)
357
+ if isinstance(tensors, Mapping):
358
+ return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
359
+
360
+ return tensors[:limit]
361
+
362
+
363
+ class DistributedTensorGatherer:
364
+ """
365
+ A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
366
+
367
+ If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
368
+ step, our sampler will generate the following indices:
369
+
370
+ `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
371
+
372
+ to get a total length that is a multiple of 3 (so that each process gets the same dataset length). Then processes 0, 1 and
373
+ 2 will be responsible for making predictions for the following samples:
374
+
375
+ - P0: `[0, 1, 2, 3, 4, 5]`
376
+ - P1: `[6, 7, 8, 9, 10, 11]`
377
+ - P2: `[12, 13, 14, 15, 0, 1]`
378
+
379
+ The first batch treated on each process will be
380
+
381
+ - P0: `[0, 1]`
382
+ - P1: `[6, 7]`
383
+ - P2: `[12, 13]`
384
+
385
+ So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
386
+ the following indices:
387
+
388
+ `[0, 1, 6, 7, 12, 13]`
389
+
390
+ If we directly concatenate our results without taking any precautions, the user will then get the predictions for
391
+ the indices in this order at the end of the prediction loop:
392
+
393
+ `[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
394
+
395
+ That is not the order the user expects; this class is there to solve that problem.
396
+
397
+ Args:
398
+ world_size (`int`):
399
+ The number of processes used in the distributed training.
400
+ num_samples (`int`):
401
+ The number of samples in our dataset.
402
+ make_multiple_of (`int`, *optional*):
403
+ If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
404
+ (by adding samples).
405
+ padding_index (`int`, *optional*, defaults to -100):
406
+ The padding index to use if the arrays don't all have the same sequence length.
407
+ """
408
+
409
+ def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
410
+ warnings.warn(
411
+ "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
412
+ FutureWarning,
413
+ )
414
+ self.world_size = world_size
415
+ self.num_samples = num_samples
416
+ total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
417
+ self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
418
+ self.process_length = self.total_samples // world_size
419
+ self._storage = None
420
+ self._offsets = None
421
+ self.padding_index = padding_index
422
+
423
+ def add_arrays(self, arrays):
424
+ """
425
+ Add `arrays` to the internal storage. Will initialize the storage to the full size at the first arrays passed
426
+ so that if we're bound to get an OOM, it happens at the beginning.
427
+ """
428
+ if arrays is None:
429
+ return
430
+ if self._storage is None:
431
+ self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
432
+ self._offsets = list(range(0, self.total_samples, self.process_length))
433
+
434
+ slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
435
+ for i in range(self.world_size):
436
+ self._offsets[i] += slice_len
437
+
438
+ def _nested_set_tensors(self, storage, arrays):
439
+ if isinstance(arrays, (list, tuple)):
440
+ result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
441
+ return result[0][0], type(arrays)(r[1] for r in result)
442
+ assert (
443
+ arrays.shape[0] % self.world_size == 0
444
+ ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
445
+
446
+ slice_len = arrays.shape[0] // self.world_size
447
+ for i in range(self.world_size):
448
+ if len(arrays.shape) == 1:
449
+ storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
450
+ else:
451
+ # Expand the array on the fly if needed.
452
+ if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
453
+ storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
454
+ storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
455
+ i * slice_len : (i + 1) * slice_len
456
+ ]
457
+ return slice_len, storage
458
+
459
+ def finalize(self):
460
+ """
461
+ Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
462
+ to get each process a dataset of the same length).
463
+ """
464
+ if self._storage is None:
465
+ return
466
+ if self._offsets[0] != self.process_length:
467
+ logger.warning("Not all data has been set. Are you sure you passed all values?")
468
+ return nested_truncate(self._storage, self.num_samples)
469
+
470
+
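Note: an illustrative sketch (not part of the committed file) of the intended usage pattern for `DistributedTensorGatherer`, assuming a world size of 2 and 10 samples; each chunk added must have a first dimension that is a multiple of the world size:

import numpy as np

gatherer = DistributedTensorGatherer(world_size=2, num_samples=10)
for _ in range(5):
    gatherer.add_arrays(np.zeros((2, 8)))  # one gathered chunk per eval step
preds = gatherer.finalize()
print(preds.shape)  # (10, 8), truncated back to num_samples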
471
+ @dataclass
472
+ class LabelSmoother:
473
+ """
474
+ Adds label-smoothing on a pre-computed output from a Transformers model.
475
+
476
+ Args:
477
+ epsilon (`float`, *optional*, defaults to 0.1):
478
+ The label smoothing factor.
479
+ ignore_index (`int`, *optional*, defaults to -100):
480
+ The index in the labels to ignore when computing the loss.
481
+ """
482
+
483
+ epsilon: float = 0.1
484
+ ignore_index: int = -100
485
+
486
+ def __call__(self, model_output, labels, shift_labels=False):
487
+ logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
488
+ if shift_labels:
489
+ logits = logits[..., :-1, :].contiguous()
490
+ labels = labels[..., 1:].contiguous()
491
+
492
+ log_probs = -nn.functional.log_softmax(logits, dim=-1)
493
+ if labels.dim() == log_probs.dim() - 1:
494
+ labels = labels.unsqueeze(-1)
495
+
496
+ padding_mask = labels.eq(self.ignore_index)
497
+ # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
498
+ # will ignore them in any case.
499
+ labels = torch.clamp(labels, min=0)
500
+ nll_loss = log_probs.gather(dim=-1, index=labels)
501
+ # works for fp16 input tensor too, by internally upcasting it to fp32
502
+ smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
503
+
504
+ nll_loss.masked_fill_(padding_mask, 0.0)
505
+ smoothed_loss.masked_fill_(padding_mask, 0.0)
506
+
507
+ # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
508
+ num_active_elements = padding_mask.numel() - padding_mask.long().sum()
509
+ nll_loss = nll_loss.sum() / num_active_elements
510
+ smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
511
+ return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
512
+
513
+
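Note: a minimal sketch of how `LabelSmoother` is typically applied to a model output (illustrative only; the tensors here are random placeholders):

import torch

smoother = LabelSmoother(epsilon=0.1)
logits = torch.randn(2, 6, 32)         # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 6))
labels[0, -2:] = -100                  # padded positions are ignored via ignore_index
loss = smoother({"logits": logits}, labels)
print(loss.item())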
514
+ def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
515
+ """
516
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
517
+ lengths. To do this, the indices are:
518
+
519
+ - randomly permuted
520
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
521
+ - sorted by length in each mega-batch
522
+
523
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
524
+ maximum length placed first, so that an OOM happens sooner rather than later.
525
+ """
526
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
527
+ if mega_batch_mult is None:
528
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
529
+ # Just in case, for tiny datasets
530
+ if mega_batch_mult == 0:
531
+ mega_batch_mult = 1
532
+
533
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
534
+ indices = torch.randperm(len(lengths), generator=generator)
535
+ megabatch_size = mega_batch_mult * batch_size
536
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
537
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
538
+
539
+ # The rest is to get the biggest batch first.
540
+ # Since each megabatch is sorted by descending length, the longest element is the first
541
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
542
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
543
+ # Switch to put the longest element in first position
544
+ megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
545
+
546
+ return [i for megabatch in megabatches for i in megabatch]
547
+
548
+
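Note: a small illustration (not part of the committed file) of the length-grouping behaviour, using made-up lengths; with 16 elements and batch size 2 the indices are grouped into mega-batches of 4:

import torch

lengths = [3, 7, 1, 9, 5, 2, 8, 6, 4, 10, 12, 11, 15, 13, 14, 16]
indices = get_length_grouped_indices(lengths, batch_size=2, generator=torch.Generator().manual_seed(0))
# Within each group of 4 consecutive indices, lengths appear in descending order,
# and the overall longest element comes first so an OOM would surface early.
print([lengths[i] for i in indices])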
549
+ class LengthGroupedSampler(Sampler):
550
+ r"""
551
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
552
+ keeping a bit of randomness.
553
+ """
554
+
555
+ def __init__(
556
+ self,
557
+ batch_size: int,
558
+ dataset: Optional[Dataset] = None,
559
+ lengths: Optional[List[int]] = None,
560
+ model_input_name: Optional[str] = None,
561
+ generator=None,
562
+ ):
563
+ if dataset is None and lengths is None:
564
+ raise ValueError("One of dataset and lengths must be provided.")
565
+
566
+ self.batch_size = batch_size
567
+ if lengths is None:
568
+ model_input_name = model_input_name if model_input_name is not None else "input_ids"
569
+ if (
570
+ not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
571
+ or model_input_name not in dataset[0]
572
+ ):
573
+ raise ValueError(
574
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
575
+ f"'{model_input_name}' key."
576
+ )
577
+ lengths = [len(feature[model_input_name]) for feature in dataset]
578
+ elif isinstance(lengths, torch.Tensor):
579
+ logger.info(
580
+ "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
581
+ )
582
+ lengths = lengths.tolist()
583
+
584
+ self.lengths = lengths
585
+ self.generator = generator
586
+
587
+ def __len__(self):
588
+ return len(self.lengths)
589
+
590
+ def __iter__(self):
591
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
592
+ return iter(indices)
593
+
594
+
595
+ class DistributedLengthGroupedSampler(DistributedSampler):
596
+ r"""
597
+ Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
598
+ length while keeping a bit of randomness.
599
+ """
600
+
601
+ # Copied and adapted from PyTorch DistributedSampler.
602
+ def __init__(
603
+ self,
604
+ batch_size: int,
605
+ dataset: Optional[Dataset] = None,
606
+ num_replicas: Optional[int] = None,
607
+ rank: Optional[int] = None,
608
+ seed: int = 0,
609
+ drop_last: bool = False,
610
+ lengths: Optional[List[int]] = None,
611
+ model_input_name: Optional[str] = None,
612
+ ):
613
+ if dataset is None and lengths is None:
614
+ raise ValueError("One of dataset and lengths must be provided.")
615
+ if num_replicas is None:
616
+ if not dist.is_available():
617
+ raise RuntimeError("Requires distributed package to be available")
618
+ num_replicas = dist.get_world_size()
619
+ if rank is None:
620
+ if not dist.is_available():
621
+ raise RuntimeError("Requires distributed package to be available")
622
+ rank = dist.get_rank()
623
+
624
+ self.batch_size = batch_size
625
+ self.num_replicas = num_replicas
626
+ self.rank = rank
627
+ self.epoch = 0
628
+ self.drop_last = drop_last
629
+
630
+ if lengths is None:
631
+ model_input_name = model_input_name if model_input_name is not None else "input_ids"
632
+ if (
633
+ not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
634
+ or model_input_name not in dataset[0]
635
+ ):
636
+ raise ValueError(
637
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
638
+ f"'{model_input_name}' key."
639
+ )
640
+ lengths = [len(feature[model_input_name]) for feature in dataset]
641
+ elif isinstance(lengths, torch.Tensor):
642
+ logger.info(
643
+ "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
644
+ " List[int]..."
645
+ )
646
+ lengths = lengths.tolist()
647
+
648
+ self.lengths = lengths
649
+
650
+ # If the dataset length is evenly divisible by # of replicas, then there
651
+ # is no need to drop any data, since the dataset will be split equally.
652
+ if self.drop_last and len(self.lengths) % self.num_replicas != 0:
653
+ # Split to nearest available length that is evenly divisible.
654
+ # This is to ensure each rank receives the same amount of data when
655
+ # using this Sampler.
656
+ self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
657
+ else:
658
+ self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
659
+ self.total_size = self.num_samples * self.num_replicas
660
+ self.seed = seed
661
+
662
+ def __iter__(self) -> Iterator:
663
+ # Deterministically shuffle based on epoch and seed
664
+ g = torch.Generator()
665
+ g.manual_seed(self.seed + self.epoch)
666
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
667
+
668
+ if not self.drop_last:
669
+ # add extra samples to make it evenly divisible
670
+ indices += indices[: (self.total_size - len(indices))]
671
+ else:
672
+ # remove tail of data to make it evenly divisible.
673
+ indices = indices[: self.total_size]
674
+ assert len(indices) == self.total_size
675
+
676
+ # subsample
677
+ indices = indices[self.rank : self.total_size : self.num_replicas]
678
+ assert len(indices) == self.num_samples
679
+
680
+ return iter(indices)
681
+
682
+
683
+ class ShardSampler(Sampler):
684
+ """
685
+ Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
686
+ size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into
687
+ `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1.
688
+
689
+ The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1.
690
+ """
691
+
692
+ def __init__(
693
+ self,
694
+ dataset: Dataset,
695
+ batch_size: int = 1,
696
+ drop_last: bool = False,
697
+ num_processes: int = 1,
698
+ process_index: int = 0,
699
+ ):
700
+ self.dataset = dataset
701
+ self.batch_size = batch_size
702
+ self.drop_last = drop_last
703
+ self.num_processes = num_processes
704
+ self.process_index = process_index
705
+
706
+ self.total_batch_size = total_batch_size = batch_size * num_processes
707
+
708
+ num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
709
+ self.total_num_samples = num_batches * total_batch_size
710
+
711
+ def __iter__(self):
712
+ indices = list(range(len(self.dataset)))
713
+
714
+ # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
715
+ # and it needs to be done several times.
716
+ while len(indices) < self.total_num_samples:
717
+ indices += indices[: (self.total_num_samples - len(indices))]
718
+
719
+ result = []
720
+ for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
721
+ result += indices[batch_start : batch_start + self.batch_size]
722
+
723
+ return iter(result)
724
+
725
+ def __len__(self):
726
+ # Each shard only sees a fraction of total_num_samples.
727
+ return self.total_num_samples // self.num_processes
728
+
729
+
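Note: an illustrative check (not part of the committed file) reproducing the two-process example from the `ShardSampler` docstring, using a plain list as the dataset:

dataset = list(range(16))
shard0 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0)
shard1 = ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1)
print(list(shard0))  # [0, 1, 2, 3, 8, 9, 10, 11]
print(list(shard1))  # [4, 5, 6, 7, 12, 13, 14, 15]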
730
+ class IterableDatasetShard(IterableDataset):
731
+ """
732
+ Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
733
+ always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x
734
+ num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the
735
+ first batch that would be too small or loop with indices from the beginning.
736
+
737
+ On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of
738
+ 2:
739
+
740
+ - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]`
741
+ - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]`
742
+
743
+ <Tip warning={true}>
744
+
745
+ If your IterableDataset implements some randomization that needs to be applied the same way on all processes
746
+ (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to
747
+ generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this
748
+ object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the
749
+ iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with
750
+ this.
751
+
752
+ </Tip>
753
+
754
+ Args:
755
+ dataset (`torch.utils.data.IterableDataset`):
756
+ The dataset to split into several shards.
757
+ batch_size (`int`, *optional*, defaults to 1):
758
+ The size of the batches per shard.
759
+ drop_last (`bool`, *optional*, defaults to `False`):
760
+ Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
761
+ beginning.
762
+ num_processes (`int`, *optional*, defaults to 1):
763
+ The number of processes running concurrently.
764
+ process_index (`int`, *optional*, defaults to 0):
765
+ The index of the current process.
766
+ seed (`int`, *optional*, defaults to 0):
767
+ A random seed that will be used for the random number generation in
768
+ [`~trainer_pt_utils.IterableDatasetShard.set_epoch`].
769
+ """
770
+
771
+ def __init__(
772
+ self,
773
+ dataset: IterableDataset,
774
+ batch_size: int = 1,
775
+ drop_last: bool = False,
776
+ num_processes: int = 1,
777
+ process_index: int = 0,
778
+ seed: int = 0,
779
+ ):
780
+ self.dataset = dataset
781
+ self.batch_size = batch_size
782
+ self.drop_last = drop_last
783
+ self.num_processes = num_processes
784
+ self.process_index = process_index
785
+ self.seed = seed
786
+ self.epoch = 0
787
+ self.num_examples = 0
788
+
789
+ def set_epoch(self, epoch):
790
+ self.epoch = epoch
791
+ if hasattr(self.dataset, "set_epoch"):
792
+ self.dataset.set_epoch(epoch)
793
+
794
+ def __iter__(self):
795
+ self.num_examples = 0
796
+ if (
797
+ not hasattr(self.dataset, "set_epoch")
798
+ and hasattr(self.dataset, "generator")
799
+ and isinstance(self.dataset.generator, torch.Generator)
800
+ ):
801
+ self.dataset.generator.manual_seed(self.seed + self.epoch)
802
+ real_batch_size = self.batch_size * self.num_processes
803
+ process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
804
+
805
+ first_batch = None
806
+ current_batch = []
807
+ for element in self.dataset:
808
+ self.num_examples += 1
809
+ current_batch.append(element)
810
+ # Wait to have a full batch before yielding elements.
811
+ if len(current_batch) == real_batch_size:
812
+ for i in process_slice:
813
+ yield current_batch[i]
814
+ if first_batch is None:
815
+ first_batch = current_batch.copy()
816
+ current_batch = []
817
+
818
+ # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
819
+ if not self.drop_last and len(current_batch) > 0:
820
+ if first_batch is None:
821
+ first_batch = current_batch.copy()
822
+ while len(current_batch) < real_batch_size:
823
+ current_batch += first_batch
824
+ for i in process_slice:
825
+ yield current_batch[i]
826
+
827
+ def __len__(self):
828
+ # Will raise an error if the underlying dataset is not sized.
829
+ if self.drop_last:
830
+ return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
831
+ else:
832
+ return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
833
+
834
+
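Note: a quick illustration (not part of the committed file) matching the two-process example in the `IterableDatasetShard` docstring, wrapping a trivial iterable dataset:

from torch.utils.data import IterableDataset

class RangeDataset(IterableDataset):
    def __iter__(self):
        return iter(range(12))

shard0 = IterableDatasetShard(RangeDataset(), batch_size=2, num_processes=2, process_index=0)
shard1 = IterableDatasetShard(RangeDataset(), batch_size=2, num_processes=2, process_index=1)
print(list(shard0))  # [0, 1, 4, 5, 8, 9]
print(list(shard1))  # [2, 3, 6, 7, 10, 11]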
835
+ # In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
836
+ # helper methods here
837
+
838
+
839
+ def _get_learning_rate(self):
840
+ if self.deepspeed:
841
+ # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
842
+ # not run for the first few dozen steps while loss scale is too large, and thus during
843
+ # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
844
+ try:
845
+ last_lr = self.lr_scheduler.get_last_lr()[0]
846
+ except AssertionError as e:
847
+ if "need to call step" in str(e):
848
+ logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
849
+ last_lr = 0
850
+ else:
851
+ raise
852
+ else:
853
+ last_lr = self.lr_scheduler.get_last_lr()[0]
854
+ if torch.is_tensor(last_lr):
855
+ last_lr = last_lr.item()
856
+ return last_lr
857
+
858
+
859
+ def _secs2timedelta(secs):
860
+ """
861
+ convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
862
+ """
863
+
864
+ msec = int(abs(secs - int(secs)) * 100)
865
+ return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"
866
+
867
+
868
+ def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
869
+ """
870
+ Reformat Trainer metrics values to a human-readable format
871
+
872
+ Args:
873
+ metrics (`Dict[str, float]`):
874
+ The metrics returned from train/evaluate/predict
875
+
876
+ Returns:
877
+ metrics (`Dict[str, float]`): The reformatted metrics
878
+ """
879
+
880
+ metrics_copy = metrics.copy()
881
+ for k, v in metrics_copy.items():
882
+ if "_mem_" in k:
883
+ metrics_copy[k] = f"{ v >> 20 }MB"
884
+ elif "_runtime" in k:
885
+ metrics_copy[k] = _secs2timedelta(v)
886
+ elif k == "total_flos":
887
+ metrics_copy[k] = f"{ int(v) >> 30 }GF"
888
+ elif type(metrics_copy[k]) == float:
889
+ metrics_copy[k] = round(v, 4)
890
+
891
+ return metrics_copy
892
+
893
+
894
+ def log_metrics(self, split, metrics):
895
+ """
896
+ Log metrics in a specially formatted way
897
+
898
+ Under distributed environment this is done only for a process with rank 0.
899
+
900
+ Args:
901
+ split (`str`):
902
+ Mode/split name: one of `train`, `eval`, `test`
903
+ metrics (`Dict[str, float]`):
904
+ The metrics returned from train/evaluate/predict
905
+
906
+ Notes on memory reports:
907
+
908
+ In order to get memory usage report you need to install `psutil`. You can do that with `pip install psutil`.
909
+
910
+ Now when this method is run, you will see a report that will include:
911
+
912
+ ```
913
+ init_mem_cpu_alloc_delta = 1301MB
914
+ init_mem_cpu_peaked_delta = 154MB
915
+ init_mem_gpu_alloc_delta = 230MB
916
+ init_mem_gpu_peaked_delta = 0MB
917
+ train_mem_cpu_alloc_delta = 1345MB
918
+ train_mem_cpu_peaked_delta = 0MB
919
+ train_mem_gpu_alloc_delta = 693MB
920
+ train_mem_gpu_peaked_delta = 7MB
921
+ ```
922
+
923
+ **Understanding the reports:**
924
+
925
+ - the first segment, e.g. `train__`, tells you which stage the metrics are for. Reports starting with `init_`
926
+ will be added to the first stage that gets run, so if only evaluation is run, the memory usage for the
927
+ `__init__` will be reported along with the `eval_` metrics.
928
+ - the third segment is either `cpu` or `gpu` and tells you whether it's the general RAM or the gpu0 memory
929
+ metric.
930
+ - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the
931
+ stage - it can be negative if a function released more memory than it allocated.
932
+ - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated
933
+ memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` +
934
+ `peaked_delta` and you know how much memory was needed to complete that stage.
935
+
936
+ The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
937
+ main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may
938
+ use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more
939
+ memory than the rest since it stores the gradient and optimizer states for all participating GPUS. Perhaps in the
940
+ future these reports will evolve to measure those too.
941
+
942
+ The CPU RAM metric measures RSS (Resident Set Size), which includes both the memory unique to the process and the
943
+ memory shared with other processes. It is important to note that it does not include swapped out memory, so the
944
+ reports could be imprecise.
945
+
946
+ The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
947
+ that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than
948
+ reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations
949
+ outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it
950
+ was dropped in favor of the memory sampling approach, which reads the current process memory usage.
951
+
952
+ The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and
953
+ `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as
954
+ `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very
955
+ first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.
956
+
957
+ Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`,
958
+ `evaluate` and `predict` calls.
959
+
960
+ Since `evaluation` calls may happen during `train`, we can't handle nested invocations because
961
+ `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker
962
+ will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved
963
+ it will be possible to change this class to be re-entrant. Until then we will only track the outer level of
964
+ `train`, `evaluate` and `predict` methods. This means that if `eval` is called during `train`, it's the latter
965
+ that will account for its memory usage and that of the former.
966
+
967
+ This also means that if any other tool that is used along the [`Trainer`] calls
968
+ `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt
969
+ the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves.
970
+
971
+ For best performance you may want to consider turning the memory profiling off for production runs.
972
+ """
973
+ if not self.is_world_process_zero():
974
+ return
975
+
976
+ print(f"***** {split} metrics *****")
977
+ metrics_formatted = self.metrics_format(metrics)
978
+ k_width = max(len(str(x)) for x in metrics_formatted.keys())
979
+ v_width = max(len(str(x)) for x in metrics_formatted.values())
980
+ for key in sorted(metrics_formatted.keys()):
981
+ print(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
982
+
983
+
984
+ def save_metrics(self, split, metrics, combined=True):
985
+ """
986
+ Save metrics into a json file for that split, e.g. `train_results.json`.
987
+
988
+ Under distributed environment this is done only for a process with rank 0.
989
+
990
+ Args:
991
+ split (`str`):
992
+ Mode/split name: one of `train`, `eval`, `test`, `all`
993
+ metrics (`Dict[str, float]`):
994
+ The metrics returned from train/evaluate/predict
995
+ combined (`bool`, *optional*, defaults to `True`):
996
+ Creates combined metrics by updating `all_results.json` with metrics of this call
997
+
998
+ To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
999
+ unformatted numbers are saved in the current method.
1000
+
1001
+ """
1002
+ if not self.is_world_process_zero():
1003
+ return
1004
+
1005
+ path = os.path.join(self.args.output_dir, f"{split}_results.json")
1006
+ with open(path, "w") as f:
1007
+ json.dump(metrics, f, indent=4, sort_keys=True)
1008
+
1009
+ if combined:
1010
+ path = os.path.join(self.args.output_dir, "all_results.json")
1011
+ if os.path.exists(path):
1012
+ with open(path, "r") as f:
1013
+ all_metrics = json.load(f)
1014
+ else:
1015
+ all_metrics = {}
1016
+
1017
+ all_metrics.update(metrics)
1018
+ with open(path, "w") as f:
1019
+ json.dump(all_metrics, f, indent=4, sort_keys=True)
1020
+
1021
+
1022
+ def save_state(self):
1023
+ """
1024
+ Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
1025
+
1026
+ Under distributed environment this is done only for a process with rank 0.
1027
+ """
1028
+ if not self.is_world_process_zero():
1029
+ return
1030
+
1031
+ path = os.path.join(self.args.output_dir, "trainer_state.json")
1032
+ self.state.save_to_json(path)
1033
+
1034
+
1035
+ def get_parameter_names(model, forbidden_layer_types):
1036
+ """
1037
+ Returns the names of the model parameters that are not inside a forbidden layer.
1038
+ """
1039
+ result = []
1040
+ for name, child in model.named_children():
1041
+ result += [
1042
+ f"{name}.{n}"
1043
+ for n in get_parameter_names(child, forbidden_layer_types)
1044
+ if not isinstance(child, tuple(forbidden_layer_types))
1045
+ ]
1046
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
1047
+ result += list(model._parameters.keys())
1048
+ return result
1049
+
1050
+
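Note: a short sketch (not part of the committed file) of how `get_parameter_names` is commonly used to build optimizer parameter groups that exclude LayerNorm and bias parameters from weight decay:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-3)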
1051
+ def get_module_class_from_name(module, name):
1052
+ """
1053
+ Gets a class from a module by its name.
1054
+
1055
+ Args:
1056
+ module (`torch.nn.Module`): The module to get the class from.
1057
+ name (`str`): The name of the class.
1058
+ """
1059
+ modules_children = list(module.children())
1060
+ if module.__class__.__name__ == name:
1061
+ return module.__class__
1062
+ elif len(modules_children) == 0:
1063
+ return
1064
+ else:
1065
+ for child_module in modules_children:
1066
+ module_class = get_module_class_from_name(child_module, name)
1067
+ if module_class is not None:
1068
+ return module_class
1069
+
1070
+
1071
+ if is_sagemaker_mp_enabled():
1072
+ import smdistributed.modelparallel.torch as smp
1073
+
1074
+ @smp.step()
1075
+ def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
1076
+ outputs = model(**inputs)
1077
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
1078
+ loss /= gradient_accumulation_steps
1079
+ model.backward(loss)
1080
+ return loss
1081
+
1082
+ @smp.step()
1083
+ def smp_forward_only(model, inputs):
1084
+ return model(**inputs)
1085
+
1086
+ def smp_gather(tensor):
1087
+ if isinstance(tensor, (list, tuple)):
1088
+ return type(tensor)(smp_gather(t) for t in tensor)
1089
+ elif isinstance(tensor, dict):
1090
+ return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
1091
+ elif not isinstance(tensor, torch.Tensor):
1092
+ raise TypeError(
1093
+ f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
1094
+ )
1095
+ all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
1096
+ all_tensors = [atleast_1d(t) for t in all_tensors]
1097
+ return torch.cat([t.cpu() for t in all_tensors], dim=0)
1098
+
1099
+ def smp_nested_concat(tensor):
1100
+ if isinstance(tensor, (list, tuple)):
1101
+ return type(tensor)(smp_nested_concat(t) for t in tensor)
1102
+ elif isinstance(tensor, dict):
1103
+ return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
1104
+ # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
1105
+ # which is also the name of the decorator so Python is confused.
1106
+ return tensor.concat().detach().cpu()