HarryLee committed on
Commit
ef8be5c
1 Parent(s): fa57f25

initial commit

Files changed (1)
  1. trainer.py +1531 -0
trainer.py ADDED
@@ -0,0 +1,1531 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Train a network across multiple GPUs.
8
+ """
9
+
10
+ import contextlib
11
+ import logging
12
+ import sys
13
+ import time
14
+ from argparse import Namespace
15
+ from itertools import chain
16
+ from typing import Any, Dict, List
17
+
18
+ import torch
19
+ from fairseq import models, optim, utils
20
+ from fairseq.dataclass.configs import FairseqConfig
21
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
+ from fairseq.distributed import utils as distributed_utils
23
+ from fairseq.file_io import PathManager
24
+ from fairseq.logging import meters, metrics
25
+ from fairseq.models.ema import build_ema
26
+ from fairseq.nan_detector import NanDetector
27
+ from fairseq.optim import lr_scheduler
28
+ from omegaconf import OmegaConf
29
+
30
+ from utils import checkpoint_utils
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class Trainer(object):
36
+ """Main class for data parallel training.
37
+
38
+ This class supports synchronous distributed data parallel training,
39
+ where multiple workers each have a full model replica and gradients
40
+ are accumulated across workers before each update. We use
41
+ :class:`~torch.nn.parallel.DistributedDataParallel` to handle
42
+ communication of the gradients across workers.
43
+ """
44
+
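# --- Editorial sketch (not part of the committed file) ----------------------
# Minimal outline of how a training driver typically exercises this class.
# `cfg`, `task`, `model` and `criterion` are assumed to be built elsewhere by
# fairseq's usual setup code; the checkpoint file name below is a placeholder.

def _example_training_loop(cfg, task, model, criterion):
    trainer = Trainer(cfg, task, model, criterion)
    # returns extra_state, or None if the file does not exist
    trainer.load_checkpoint(cfg.checkpoint.restore_file)
    epoch_itr = trainer.get_train_iterator(epoch=1)

    trainer.begin_epoch(epoch_itr.epoch)
    for sample in epoch_itr.next_epoch_itr(shuffle=True):
        # train_step() takes a *list* of mini-batches; more than one element
        # means gradient accumulation (--update-freq).
        trainer.train_step([sample])
    trainer.lr_step(epoch_itr.epoch)
    trainer.save_checkpoint(
        "checkpoint_example.pt",
        {"train_iterator": epoch_itr.state_dict()},
    )
# ----------------------------------------------------------------------------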
45
+ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
46
+
47
+ if isinstance(cfg, Namespace):
48
+ logger.warning(
49
+ "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
50
+ )
51
+ cfg = convert_namespace_to_omegaconf(cfg)
52
+
53
+ self.cfg = cfg
54
+ self.task = task
55
+
56
+ # catalog shared parameters
57
+ shared_params = _catalog_shared_params(model)
58
+ self.tpu = cfg.common.tpu
59
+ self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
60
+ if self.cuda:
61
+ self.device = torch.device("cuda")
62
+ elif self.tpu:
63
+ self.device = utils.get_tpu_device()
64
+ else:
65
+ self.device = torch.device("cpu")
66
+
67
+ if self.is_fsdp:
68
+ import fairscale
69
+ if self.cfg.common.bf16:
70
+ raise ValueError(
71
+ "FullyShardedDataParallel is not compatible with --bf16 or "
72
+ "--memory-efficient-bf16"
73
+ )
74
+ if self.cfg.distributed_training.zero_sharding != "none":
75
+ raise ValueError(
76
+ "FullyShardedDataParallel is not compatible with --zero-sharding "
77
+ "option (it's already built in)"
78
+ )
79
+ if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0":
80
+ raise RuntimeError(
81
+ "Please update to fairscale 0.4.0 or newer when combining "
82
+ "--update-freq with FullyShardedDataParallel"
83
+ )
84
+ else:
85
+ if (
86
+ hasattr(self.cfg.distributed_training, "cpu_offload")
87
+ and self.cfg.distributed_training.cpu_offload
88
+ ):
89
+ raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
90
+
91
+ # copy model and criterion to current device/dtype
92
+ self._criterion = criterion
93
+ self._model = model
94
+ if not self.is_fsdp:
95
+ if cfg.common.fp16:
96
+ assert not cfg.common.amp, "Cannot use fp16 and AMP together"
97
+ self._criterion = self._criterion.half()
98
+ self._model = self._model.half()
99
+ elif cfg.common.bf16:
100
+ self._criterion = self._criterion.to(dtype=torch.bfloat16)
101
+ self._model = self._model.to(dtype=torch.bfloat16)
102
+ elif cfg.common.amp:
103
+ self._amp_retries = 0
104
+ if (
105
+ not cfg.distributed_training.pipeline_model_parallel
106
+ # the DistributedFairseqModel wrapper will handle moving to device,
107
+ # so only handle cases which don't use the wrapper
108
+ and not self.use_distributed_wrapper
109
+ ):
110
+ self._criterion = self._criterion.to(device=self.device)
111
+ self._model = self._model.to(device=self.device)
112
+ self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
113
+ self.last_device = None
114
+ if self.cuda and self.pipeline_model_parallel:
115
+ self.last_device = torch.device(
116
+ cfg.distributed_training.pipeline_devices[-1]
117
+ )
118
+
119
+ # check that shared parameters are preserved after device transfer
120
+ for shared_param in shared_params:
121
+ ref = _get_module_by_path(self._model, shared_param[0])
122
+ for path in shared_param[1:]:
123
+ logger.info(
124
+ "detected shared parameter: {} <- {}".format(shared_param[0], path)
125
+ )
126
+ _set_module_by_path(self._model, path, ref)
127
+
128
+ self._dummy_batch = None # indicates we don't have a dummy batch at first
129
+ self._lr_scheduler = None
130
+ self._num_updates = 0
131
+ self._num_xla_compiles = 0 # for TPUs
132
+ self._optim_history = None
133
+ self._optimizer = None
134
+ self._warn_once = set()
135
+ self._wrapped_criterion = None
136
+ self._wrapped_model = None
137
+ self._ema = None
138
+
139
+ # TODO(myleott): support tpu
140
+ if self.cuda and self.data_parallel_world_size > 1:
141
+ self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
142
+ else:
143
+ self._grad_norm_buf = None
144
+
145
+ self.quantizer = quantizer
146
+ if self.quantizer is not None:
147
+ self.quantizer.set_trainer(self)
148
+
149
+ # get detailed cuda environment
150
+ if self.cuda:
151
+ self.cuda_env = utils.CudaEnvironment()
152
+ if self.data_parallel_world_size > 1:
153
+ self.cuda_env_arr = distributed_utils.all_gather_list(
154
+ self.cuda_env, group=distributed_utils.get_global_group()
155
+ )
156
+ else:
157
+ self.cuda_env_arr = [self.cuda_env]
158
+ if self.data_parallel_rank == 0:
159
+ utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
160
+ else:
161
+ self.cuda_env = None
162
+ self.cuda_env_arr = None
163
+
164
+ metrics.log_start_time("wall", priority=790, round=0)
165
+
166
+ self._start_time = time.time()
167
+ self._previous_training_time = 0
168
+ self._cumulative_training_time = None
169
+
170
+ def reinitialize(self):
171
+ """Reinitialize the Trainer, typically after model params change."""
172
+ self._lr_scheduler = None
173
+ self._optimizer = None
174
+ self._wrapped_criterion = None
175
+ self._wrapped_model = None
176
+
177
+ @property
178
+ def data_parallel_world_size(self):
179
+ if self.cfg.distributed_training.distributed_world_size == 1:
180
+ return 1
181
+ return distributed_utils.get_data_parallel_world_size()
182
+
183
+ @property
184
+ def data_parallel_process_group(self):
185
+ return distributed_utils.get_data_parallel_group()
186
+
187
+ @property
188
+ def data_parallel_rank(self):
189
+ if self.cfg.distributed_training.distributed_world_size == 1:
190
+ return 0
191
+ return distributed_utils.get_data_parallel_rank()
192
+
193
+ @property
194
+ def is_data_parallel_master(self):
195
+ # NOTE: this returns true for all model parallel replicas with data
196
+ # parallel rank 0
197
+ return self.data_parallel_rank == 0
198
+
199
+ @property
200
+ def use_distributed_wrapper(self) -> bool:
201
+ return (
202
+ self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf
203
+ ) or (
204
+ self.is_fsdp and self.cfg.distributed_training.cpu_offload
205
+ )
206
+
207
+ @property
208
+ def should_save_checkpoint_on_current_rank(self) -> bool:
209
+ """Indicates whether to save checkpoints on the current DDP rank."""
210
+ if (
211
+ self.is_fsdp and self.cfg.distributed_training.use_sharded_state
212
+ ) or getattr(self.cfg.model, "base_layers", 0) > 0:
213
+ return True
214
+ else:
215
+ return self.is_data_parallel_master
216
+
217
+ @property
218
+ def always_call_state_dict_during_save_checkpoint(self) -> bool:
219
+ if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state:
220
+ # FSDP calls communication collective when consolidating checkpoints
221
+ return True
222
+ else:
223
+ return False
224
+
225
+ @property
226
+ def checkpoint_suffix(self) -> str:
227
+ """Suffix to add to the checkpoint file name."""
228
+ if self.is_fsdp and self.cfg.distributed_training.use_sharded_state:
229
+ return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(
230
+ self.data_parallel_rank
231
+ )
232
+ else:
233
+ return self.cfg.checkpoint.checkpoint_suffix or ""
234
+
235
+ @property
236
+ def criterion(self):
237
+ if self._wrapped_criterion is None:
238
+ if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
239
+ self._wrapped_criterion = models.DistributedFairseqModel(
240
+ self.cfg.distributed_training,
241
+ self._criterion,
242
+ process_group=self.data_parallel_process_group,
243
+ device=self.device,
244
+ )
245
+ else:
246
+ self._wrapped_criterion = self._criterion
247
+ return self._wrapped_criterion
248
+
249
+ @property
250
+ def model(self):
251
+ if self._wrapped_model is None:
252
+ if self.use_distributed_wrapper:
253
+ self._wrapped_model = models.DistributedFairseqModel(
254
+ self.cfg.distributed_training,
255
+ self._model,
256
+ process_group=self.data_parallel_process_group,
257
+ device=self.device,
258
+ )
259
+ else:
260
+ self._wrapped_model = self._model
261
+ return self._wrapped_model
262
+
263
+ @property
264
+ def ema(self):
265
+ if self._ema is None:
266
+ self._build_ema()
267
+ return self._ema
268
+
269
+ def _build_ema(self):
270
+ if self.cfg.ema.store_ema:
271
+ self._ema = build_ema(self._model, self.cfg.ema, self.device)
272
+ logger.info(
273
+ "Exponential Moving Average Shadow Model is initialized."
274
+ )
275
+
276
+ @property
277
+ def optimizer(self):
278
+ if self._optimizer is None:
279
+ self._build_optimizer()
280
+ return self._optimizer
281
+
282
+ @property
283
+ def lr_scheduler(self):
284
+ if self._lr_scheduler is None:
285
+ self._build_optimizer() # this will initialize self._lr_scheduler
286
+ return self._lr_scheduler
287
+
288
+ def _build_optimizer(self):
289
+ params = list(
290
+ filter(
291
+ lambda p: p.requires_grad,
292
+ chain(self.model.parameters(), self.criterion.parameters()),
293
+ )
294
+ )
295
+
296
+ if self.is_fsdp and self.cfg.common.fp16:
297
+ # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
298
+ # mostly for the grad scaling. But if we don't have the
299
+ # --memory-efficient-fp16 flag set, then we're effectively doing
300
+ # regular --fp16 and can allow the use of optimizers that would
301
+ # otherwise be unsupported by MemoryEfficientFP16Optimizer.
302
+ allow_unsupported = not self.cfg.common.memory_efficient_fp16
303
+ self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
304
+ self.cfg, params, allow_unsupported=allow_unsupported
305
+ )
306
+ elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
307
+ if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
308
+ logger.info(
309
+ "NOTE: your device does NOT support faster training with --fp16 or --amp, "
310
+ "please switch to FP32 which is likely to be faster"
311
+ )
312
+ if (
313
+ self.cfg.common.memory_efficient_fp16
314
+ or self.cfg.common.memory_efficient_bf16
315
+ ):
316
+ self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
317
+ self.cfg, params
318
+ )
319
+ elif self.cfg.common.amp:
320
+ self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
321
+ else:
322
+ self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
323
+ else:
324
+ if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
325
+ logger.info("NOTE: your device may support faster training with --fp16 or --amp")
326
+ self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
327
+
328
+ if self.is_fsdp:
329
+ assert (
330
+ not self.cfg.optimization.use_bmuf
331
+ ), "--ddp-backend=fully_sharded is not compatible with BMUF"
332
+ assert self._optimizer.supports_flat_params, (
333
+ "--ddp-backend=fully_sharded is only compatible with pointwise "
334
+ "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
335
+ "However, the sharding will result in slightly different results when "
336
+ "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
337
+ )
338
+
339
+ if self.cfg.optimization.use_bmuf:
340
+ self._optimizer = optim.FairseqBMUF(
341
+ self.cfg.bmuf,
342
+ self._optimizer,
343
+ )
344
+
345
+ if self.cfg.distributed_training.zero_sharding == "os":
346
+ if (
347
+ self.cfg.common.fp16
348
+ and not self.cfg.common.memory_efficient_fp16
349
+ and not self.cfg.common.memory_efficient_bf16
350
+ ) and not self.cfg.common.fp16_no_flatten_grads:
351
+ raise ValueError(
352
+ "ZeRO is incomptabile with fp16 and flattened grads. "
353
+ "Please use --fp16-no-flatten-grads"
354
+ )
355
+ else:
356
+ optim.shard_(self._optimizer, self.data_parallel_process_group)
357
+
358
+ # We should initialize the learning rate scheduler immediately after
359
+ # building the optimizer, so that the initial learning rate is set.
360
+ self._lr_scheduler = lr_scheduler.build_lr_scheduler(
361
+ self.cfg.lr_scheduler,
362
+ self.optimizer,
363
+ )
364
+ self._lr_scheduler.step_update(0)
365
+
366
+ @property
367
+ def is_fsdp(self):
368
+ return self.cfg.distributed_training.ddp_backend == "fully_sharded"
369
+
370
+ def consolidate_optimizer(self):
371
+ """For OSS, we need to consolidate the state dict."""
372
+ if self.cfg.checkpoint.no_save_optimizer_state:
373
+ return
374
+ self._gathered_optim_state = None
375
+ if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
376
+ self.optimizer.optimizer.consolidate_state_dict()
377
+ elif self.is_fsdp and not self.model.use_sharded_state:
378
+ st = self.model.gather_full_optim_state_dict(
379
+ self.optimizer
380
+ ) # only returns on rank 0
381
+ self._gathered_optim_state = st
382
+
383
+ def state_dict(self):
384
+ state_dict = {
385
+ "args": None, # legacy
386
+ "cfg": (
387
+ OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
388
+ if OmegaConf.is_config(self.cfg)
389
+ else self.cfg
390
+ ),
391
+ "model": self.model.state_dict(),
392
+ "criterion": (
393
+ self.criterion.state_dict()
394
+ if utils.has_parameters(self.criterion)
395
+ else None
396
+ ),
397
+ "optimizer_history": (self._optim_history or [])
398
+ + [
399
+ {
400
+ "criterion_name": self.get_criterion().__class__.__name__,
401
+ "optimizer_name": self.optimizer.__class__.__name__,
402
+ "lr_scheduler_state": self.lr_scheduler.state_dict(),
403
+ "num_updates": self.get_num_updates(),
404
+ }
405
+ ],
406
+ "task_state": self.task.state_dict() if self.task is not None else {},
407
+ "extra_state": {
408
+ "metrics": metrics.state_dict(),
409
+ "previous_training_time": self.cumulative_training_time(),
410
+ },
411
+ }
412
+ if self.cfg.ema.store_ema:
413
+ # Save EMA model state as extra state
414
+ state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
415
+ if self.cfg.ema.ema_fp32:
416
+ # Save EMA params in fp32
417
+ state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
418
+ if not self.cfg.checkpoint.no_save_optimizer_state:
419
+ if self._gathered_optim_state is not None:
420
+ state_dict["last_optimizer_state"] = self._gathered_optim_state
421
+ self._gathered_optim_state = None
422
+ else:
423
+ state_dict["last_optimizer_state"] = self.optimizer.state_dict()
424
+ if self.is_fsdp:
425
+ # save meta data for recombining checkpoint upon loading
426
+ state_dict["fsdp_metadata"] = self.model.local_metadata_dict()
427
+ return state_dict
428
+
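# --- Editorial sketch (not part of the committed file) ----------------------
# What the dictionary produced by state_dict() above looks like on disk,
# assuming the checkpoint was written via torch.save (as fairseq's
# torch_persistent_save ultimately does); "checkpoint_example.pt" is a
# placeholder path.

import torch

ckpt = torch.load("checkpoint_example.pt", map_location="cpu")
print(sorted(ckpt.keys()))
# Expected top-level keys: "args", "cfg", "model", "criterion",
# "optimizer_history", "task_state", "extra_state", plus
# "last_optimizer_state" unless --no-save-optimizer-state was given
# (and "fsdp_metadata" when training with --ddp-backend=fully_sharded).

last_optim = ckpt["optimizer_history"][-1]
print(last_optim["num_updates"], last_optim["optimizer_name"])
# ----------------------------------------------------------------------------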
429
+ def save_checkpoint(self, filename, extra_state):
430
+ """Save all training state in a checkpoint file."""
431
+ logger.info(f"Saving checkpoint to {filename}")
432
+ # call state_dict on all ranks in case it needs internal communication
433
+ state_dict = utils.move_to_cpu(self.state_dict())
434
+ state_dict["extra_state"].update(extra_state)
435
+ if self.should_save_checkpoint_on_current_rank:
436
+ checkpoint_utils.torch_persistent_save(
437
+ state_dict,
438
+ filename,
439
+ async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
440
+ )
441
+ logger.info(f"Finished saving checkpoint to {filename}")
442
+
443
+ def load_checkpoint(
444
+ self,
445
+ filename,
446
+ reset_optimizer=False,
447
+ reset_lr_scheduler=False,
448
+ optimizer_overrides=None,
449
+ reset_meters=False,
450
+ ):
451
+ """
452
+ Load all training state from a checkpoint file.
453
+ rank = 0 will load the checkpoint, and then broadcast it to all
454
+ other ranks.
455
+ """
456
+ extra_state, self._optim_history, last_optim_state = None, [], None
457
+
458
+ logger.info(f"Preparing to load checkpoint {filename}")
459
+ is_distributed = self.data_parallel_world_size > 1
460
+ bexists = PathManager.isfile(filename)
461
+ if bexists:
462
+ load_on_all_ranks = (
463
+ self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
464
+ # TPUs don't support broadcast yet, so load checkpoints
465
+ # on every worker for now
466
+ or self.tpu
467
+ # FSDP requires loading checkpoint shards on all ranks
468
+ or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state)
469
+ or getattr(self.cfg.model, "base_layers", 0) > 0
470
+ )
471
+
472
+ if load_on_all_ranks or self.data_parallel_rank == 0:
473
+ state = checkpoint_utils.load_checkpoint_to_cpu(
474
+ filename, load_on_all_ranks=load_on_all_ranks
475
+ )
476
+ last_optim_state = state.get("last_optimizer_state", None)
477
+
478
+ # If doing zero_sharding, do not broadcast global optimizer
479
+ # state. Later we will broadcast sharded states to each rank
480
+ # to keep memory from exploding.
481
+ if (
482
+ not load_on_all_ranks
483
+ and self.cfg.distributed_training.zero_sharding == "os"
484
+ and "last_optimizer_state" in state
485
+ and is_distributed
486
+ ):
487
+ state["last_optimizer_state"] = "SHARDED"
488
+ else:
489
+ last_optim_state = None
490
+ state = None
491
+
492
+ if is_distributed and not load_on_all_ranks:
493
+ state = distributed_utils.broadcast_object(
494
+ state,
495
+ src_rank=0,
496
+ group=self.data_parallel_process_group,
497
+ dist_device=self.device,
498
+ )
499
+ if self.data_parallel_rank > 0:
500
+ last_optim_state = state.get("last_optimizer_state", None)
501
+
502
+ # load model parameters
503
+ try:
504
+ if self.cfg.checkpoint.use_ema_weights_to_init_param and "extra_state" in state and "ema" in state["extra_state"]:
505
+ logger.info("use_ema_weights_to_init_param = True, will use EMA weights in the ckpt to init the model param...")
506
+ ema_state_dict = state["extra_state"]["ema_fp32_params"] if "ema_fp32_params" in state["extra_state"] else state["extra_state"]["ema"]
507
+ self.model.load_state_dict(
508
+ ema_state_dict, strict=True, model_cfg=self.cfg.model
509
+ )
510
+ else:
511
+ self.model.load_state_dict(
512
+ state["model"], strict=True, model_cfg=self.cfg.model
513
+ )
514
+ # save memory for later steps
515
+ if not (self.cfg.ema.store_ema and (self.cfg.checkpoint.use_latest_weights_to_init_ema or not ("extra_state" in state and "ema" in state["extra_state"]))):
516
+ del state["model"]
517
+ if utils.has_parameters(self.get_criterion()):
518
+ self.get_criterion().load_state_dict(
519
+ state["criterion"], strict=True
520
+ )
521
+ del state["criterion"]
522
+
523
+ except Exception:
524
+ raise Exception(
525
+ "Cannot load model parameters from checkpoint {}; "
526
+ "please ensure that the architectures match.".format(filename)
527
+ )
528
+ extra_state = state["extra_state"]
529
+ self._optim_history = state["optimizer_history"]
530
+
531
+ if last_optim_state is not None and not reset_optimizer:
532
+ # rebuild optimizer after loading model, since params may have changed
533
+ self._build_optimizer()
534
+
535
+ # only reload optimizer and lr_scheduler if they match
536
+ last_optim = self._optim_history[-1]
537
+ assert (
538
+ last_optim["criterion_name"] == self.get_criterion().__class__.__name__
539
+ ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
540
+ assert (
541
+ last_optim["optimizer_name"] == self.optimizer.__class__.__name__
542
+ ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
543
+
544
+ if not reset_lr_scheduler:
545
+ self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
546
+
547
+ if self.is_fsdp and not self.model.use_sharded_state:
548
+ # if use_sharded_state, the last_optim_state is already sharded, skip this
549
+ last_optim_state = self.model.get_shard_from_optim_state_dict(
550
+ last_optim_state
551
+ )
552
+ elif not load_on_all_ranks and is_distributed:
553
+ last_optim_state = self.optimizer.broadcast_global_state_dict(
554
+ last_optim_state
555
+ )
556
+
557
+ self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
558
+
559
+ self.set_num_updates(last_optim["num_updates"])
560
+
561
+ if extra_state is not None:
562
+ itr_state = extra_state["train_iterator"]
563
+ epoch = itr_state["epoch"]
564
+
565
+ if "previous_training_time" in extra_state:
566
+ self._previous_training_time = extra_state["previous_training_time"]
567
+ self._start_time = time.time()
568
+
569
+ self.lr_step(epoch)
570
+
571
+ if (
572
+ itr_state.get("version", 1) >= 2
573
+ and itr_state["iterations_in_epoch"] == 0
574
+ ):
575
+ # reset meters at start of epoch
576
+ reset_meters = True
577
+
578
+ if "metrics" in extra_state and not reset_meters:
579
+ metrics.load_state_dict(extra_state["metrics"])
580
+
581
+ # reset TimeMeters, since their start times don't make sense anymore
582
+ for meter in metrics.get_meters("default"):
583
+ if isinstance(meter, meters.TimeMeter):
584
+ meter.reset()
585
+
586
+ if self.cfg.ema.store_ema:
587
+ if self.cfg.checkpoint.use_latest_weights_to_init_ema or "ema" not in extra_state:
588
+ if "ema" not in extra_state:
589
+ logger.warning(
590
+ "EMA not found in checkpoint. But store_ema is True. "
591
+ "EMA is re-initialized from checkpoint."
592
+ )
593
+ elif self.cfg.checkpoint.use_latest_weights_to_init_ema:
594
+ logger.info(
595
+ "use_latest_weights_to_init_ema = True. EMA is re-initialized from checkpoint."
596
+ )
597
+ self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
598
+ del state["model"]
599
+ else:
600
+ logger.info(
601
+ "Loading EMA from checkpoint"
602
+ )
603
+ self.ema.restore(extra_state["ema"], build_fp32_params=False)
604
+
605
+ if self.cfg.ema.ema_fp32:
606
+ if "ema_fp32_params" in extra_state:
607
+ logger.info(
608
+ "Loading EMA fp32 params from checkpoint"
609
+ )
610
+ self.ema.build_fp32_params(extra_state["ema_fp32_params"])
611
+ else:
612
+ logger.info(
613
+ "Building EMA fp32 params from EMA model in checkpoint"
614
+ )
615
+ self.ema.build_fp32_params()
616
+
617
+ logger.info(
618
+ "Loaded checkpoint {} (epoch {} @ {} updates)".format(
619
+ filename, epoch, self.get_num_updates()
620
+ )
621
+ )
622
+
623
+ else:
624
+ logger.info("No existing checkpoint found {}".format(filename))
625
+
626
+ return extra_state
627
+
628
+ def get_train_iterator(
629
+ self,
630
+ epoch,
631
+ combine=True,
632
+ load_dataset=True,
633
+ data_selector=None,
634
+ shard_batch_itr=True,
635
+ disable_iterator_cache=False,
636
+ ):
637
+ """Return an EpochBatchIterator over the training set for a given epoch."""
638
+ if load_dataset:
639
+ logger.info("loading train data for epoch {}".format(epoch))
640
+ self.task.load_dataset(
641
+ self.cfg.dataset.train_subset,
642
+ epoch=epoch,
643
+ combine=combine,
644
+ data_selector=data_selector,
645
+ tpu=self.tpu,
646
+ )
647
+ batch_iterator = self.task.get_batch_iterator(
648
+ dataset=self.task.dataset(self.cfg.dataset.train_subset),
649
+ max_tokens=self.cfg.dataset.max_tokens,
650
+ max_sentences=self.cfg.dataset.batch_size,
651
+ max_positions=utils.resolve_max_positions(
652
+ self.task.max_positions(),
653
+ self.model.max_positions(),
654
+ self.cfg.dataset.max_tokens,
655
+ ),
656
+ ignore_invalid_inputs=True,
657
+ required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
658
+ seed=self.cfg.common.seed,
659
+ num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
660
+ shard_id=self.data_parallel_rank if shard_batch_itr else 0,
661
+ num_workers=self.cfg.dataset.num_workers,
662
+ epoch=epoch,
663
+ data_buffer_size=self.cfg.dataset.data_buffer_size,
664
+ disable_iterator_cache=disable_iterator_cache,
665
+ )
666
+ self.reset_dummy_batch(batch_iterator.first_batch)
667
+ batch_iterator.dataset.dataset._seek()
668
+ return batch_iterator
669
+
670
+ def get_valid_iterator(
671
+ self,
672
+ subset,
673
+ disable_iterator_cache=False,
674
+ ):
675
+ """Return an EpochBatchIterator over given validation subset for a given epoch."""
676
+ self.task.dataset(subset).dataset._seek()
677
+ batch_iterator = self.task.get_batch_iterator(
678
+ dataset=self.task.dataset(subset),
679
+ max_tokens=self.cfg.dataset.max_tokens_valid,
680
+ max_sentences=self.cfg.dataset.batch_size_valid,
681
+ max_positions=utils.resolve_max_positions(
682
+ self.task.max_positions(),
683
+ self.model.max_positions(),
684
+ ),
685
+ ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
686
+ required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
687
+ seed=self.cfg.common.seed,
688
+ num_shards=self.data_parallel_world_size,
689
+ shard_id=self.data_parallel_rank,
690
+ num_workers=self.cfg.dataset.num_workers,
691
+ # always pass a fixed "epoch" to keep validation data consistent
692
+ # across training epochs
693
+ epoch=1,
694
+ data_buffer_size=self.cfg.dataset.data_buffer_size,
695
+ disable_iterator_cache=disable_iterator_cache,
696
+ )
697
+ self.reset_dummy_batch(batch_iterator.first_batch)
698
+ batch_iterator.dataset.dataset._seek()
699
+ return batch_iterator
700
+
701
+ def begin_epoch(self, epoch):
702
+ """Called at the beginning of each epoch."""
703
+ logger.info("begin training epoch {}".format(epoch))
704
+
705
+ self.lr_step_begin_epoch(epoch)
706
+
707
+ if self.quantizer is not None:
708
+ self.quantizer.begin_epoch(epoch)
709
+
710
+ # task specific setup per epoch
711
+ self.task.begin_epoch(epoch, self.get_model())
712
+
713
+ if self.tpu:
714
+ import torch_xla.core.xla_model as xm
715
+
716
+ xm.rendezvous("begin_epoch") # wait for all workers
717
+ xm.mark_step()
718
+
719
+ def begin_valid_epoch(self, epoch):
720
+ """Called at the beginning of each validation epoch."""
721
+
722
+ # task specific setup per validation epoch
723
+ self.task.begin_valid_epoch(epoch, self.get_model())
724
+
725
+ def reset_dummy_batch(self, batch):
726
+ self._dummy_batch = batch
727
+
728
+ @metrics.aggregate("train")
729
+ def train_step(self, samples, raise_oom=False):
730
+ """Do forward, backward and parameter update."""
731
+ self._set_seed()
732
+ self.model.train()
733
+ self.criterion.train()
734
+ self.zero_grad()
735
+
736
+ metrics.log_start_time("train_wall", priority=800, round=0)
737
+
738
+ # If EMA is enabled through store_ema=True
739
+ # and task.uses_ema is True, pass the EMA model as a keyword
740
+ # argument to the task.
741
+ extra_kwargs = {}
742
+ if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
743
+ extra_kwargs["ema_model"] = self.ema.get_model()
744
+
745
+ # forward and backward pass
746
+ logging_outputs, sample_size, ooms = [], 0, 0
747
+ for i, sample in enumerate(samples): # delayed update loop
748
+ sample, is_dummy_batch = self._prepare_sample(sample)
749
+
750
+ def maybe_no_sync():
751
+ """
752
+ Whenever *samples* contains more than one mini-batch, we
753
+ want to accumulate gradients locally and only call
754
+ all-reduce in the last backwards pass.
755
+ """
756
+ if (
757
+ self.data_parallel_world_size > 1
758
+ and hasattr(self.model, "no_sync")
759
+ and i < len(samples) - 1
760
+ # The no_sync context manager results in increased memory
761
+ # usage with FSDP, since full-size gradients will be
762
+ # accumulated on each GPU. It's typically a better tradeoff
763
+ # to do the extra communication with FSDP.
764
+ and not self.is_fsdp
765
+ ):
766
+ return self.model.no_sync()
767
+ else:
768
+ return contextlib.ExitStack() # dummy contextmanager
769
+
770
+ try:
771
+ with maybe_no_sync():
772
+ # forward and backward
773
+ loss, sample_size_i, logging_output = self.task.train_step(
774
+ sample=sample,
775
+ model=self.model,
776
+ criterion=self.criterion,
777
+ optimizer=self.optimizer,
778
+ update_num=self.get_num_updates(),
779
+ ignore_grad=is_dummy_batch,
780
+ **extra_kwargs,
781
+ )
782
+ del loss
783
+
784
+ logging_outputs.append(logging_output)
785
+ sample_size += sample_size_i
786
+
787
+ # emptying the CUDA cache after the first step can
788
+ # reduce the chance of OOM
789
+ if self.cuda and self.get_num_updates() == 0:
790
+ torch.cuda.empty_cache()
791
+ except RuntimeError as e:
792
+ if "out of memory" in str(e):
793
+ self._log_oom(e)
794
+ if raise_oom:
795
+ raise e
796
+ logger.warning(
797
+ "attempting to recover from OOM in forward/backward pass"
798
+ )
799
+ ooms += 1
800
+ self.zero_grad()
801
+ if self.cuda:
802
+ torch.cuda.empty_cache()
803
+ if self.cfg.distributed_training.distributed_world_size == 1:
804
+ return None
805
+ else:
806
+ raise e
807
+
808
+ if self.tpu and i < len(samples) - 1:
809
+ # tpu-comment: every XLA operation before marking step is
810
+ # appended to the IR graph, and processing too many batches
811
+ # before marking step can lead to OOM errors.
812
+ # To handle gradient accumulation use case, we explicitly
813
+ # mark step here for every forward pass without a backward pass
814
+ self._xla_markstep_and_send_to_cpu()
815
+
816
+ if is_dummy_batch:
817
+ if torch.is_tensor(sample_size):
818
+ sample_size.zero_()
819
+ else:
820
+ sample_size *= 0.0
821
+
822
+ if torch.is_tensor(sample_size):
823
+ sample_size = sample_size.float()
824
+ else:
825
+ sample_size = float(sample_size)
826
+
827
+ # gather logging outputs from all replicas
828
+ if self._sync_stats():
829
+ train_time = self._local_cumulative_training_time()
830
+ logging_outputs, (
831
+ sample_size,
832
+ ooms,
833
+ total_train_time,
834
+ ) = self._aggregate_logging_outputs(
835
+ logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
836
+ )
837
+ self._cumulative_training_time = (
838
+ total_train_time / self.data_parallel_world_size
839
+ )
840
+
841
+ overflow = False
842
+ try:
843
+ with torch.autograd.profiler.record_function("reduce-grads"):
844
+ # reduce gradients across workers
845
+ self.optimizer.all_reduce_grads(self.model)
846
+ if utils.has_parameters(self.criterion):
847
+ self.optimizer.all_reduce_grads(self.criterion)
848
+
849
+ with torch.autograd.profiler.record_function("multiply-grads"):
850
+ # multiply gradients by (data_parallel_size / sample_size) since
851
+ # DDP normalizes by the number of data parallel workers for
852
+ # improved fp16 precision.
853
+ # Thus we get (sum_of_gradients / sample_size) at the end.
854
+ # In case of fp16, this step also undoes loss scaling.
855
+ # (Debugging note: Some optimizers perform this scaling on the
856
+ # fly, so inspecting model.parameters() or optimizer.params may
857
+ # still show the original, unscaled gradients.)
858
+ numer = (
859
+ self.data_parallel_world_size
860
+ if not self.cfg.optimization.use_bmuf or self._sync_stats()
861
+ else 1
862
+ )
863
+ self.optimizer.multiply_grads(numer / (sample_size or 1.0))
864
+ # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
865
+ # way that avoids CPU/device transfers in case sample_size is a GPU or
866
+ # TPU object. The assumption is that the gradient itself is also 0.
867
+
868
+ with torch.autograd.profiler.record_function("clip-grads"):
869
+ # clip grads
870
+ grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
871
+
872
+ # check that grad norms are consistent across workers
873
+ # on tpu check tensor is slow
874
+ if not self.tpu:
875
+ if (
876
+ not self.cfg.optimization.use_bmuf
877
+ and self.cfg.distributed_training.ddp_backend != "slow_mo"
878
+ ):
879
+ self._check_grad_norms(grad_norm)
880
+ if not torch.isfinite(grad_norm).all():
881
+ # in case of AMP, if gradients are Nan/Inf then
882
+ # optimizer step is still required
883
+ if self.cfg.common.amp:
884
+ overflow = True
885
+ else:
886
+ # check local gradnorm single GPU case, trigger NanDetector
887
+ raise FloatingPointError("gradients are Nan/Inf")
888
+
889
+ with torch.autograd.profiler.record_function("optimizer"):
890
+ # take an optimization step
891
+ self.task.optimizer_step(
892
+ self.optimizer, model=self.model, update_num=self.get_num_updates()
893
+ )
894
+ if self.cfg.common.amp and overflow:
895
+ if self._amp_retries == self.cfg.common.amp_batch_retries:
896
+ logger.info("AMP: skipping this batch.")
897
+ self._amp_retries = 0
898
+ else:
899
+ self._amp_retries += 1
900
+ return self.train_step(samples, raise_oom) # recursion to feed in same batch
901
+
902
+ except FloatingPointError:
903
+ # re-run the forward and backward pass with hooks attached to print
904
+ # out where it fails
905
+ self.zero_grad()
906
+ with NanDetector(self.get_model()):
907
+ for _, sample in enumerate(samples):
908
+ sample, _ = self._prepare_sample(sample)
909
+ self.task.train_step(
910
+ sample,
911
+ self.model,
912
+ self.criterion,
913
+ self.optimizer,
914
+ self.get_num_updates(),
915
+ ignore_grad=False,
916
+ **extra_kwargs,
917
+ )
918
+ raise
919
+ except OverflowError as e:
920
+ overflow = True
921
+ logger.info(
922
+ f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
923
+ )
924
+ grad_norm = torch.tensor(0.0).cuda()
925
+ self.zero_grad()
926
+ except RuntimeError as e:
927
+ if "out of memory" in str(e):
928
+ self._log_oom(e)
929
+ logger.error("OOM during optimization, irrecoverable")
930
+ raise e
931
+
932
+ # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
933
+ # after the step
934
+ if hasattr(self.model, "perform_additional_optimizer_actions"):
935
+ if hasattr(self.optimizer, "fp32_params"):
936
+ self.model.perform_additional_optimizer_actions(
937
+ self.optimizer.optimizer, self.optimizer.fp32_params
938
+ )
939
+ else:
940
+ self.model.perform_additional_optimizer_actions(
941
+ self.optimizer.optimizer
942
+ )
943
+
944
+ logging_output = None
945
+ if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
946
+ self.set_num_updates(self.get_num_updates() + 1)
947
+
948
+ if self.cfg.ema.store_ema:
949
+ # Step EMA forward with new model.
950
+ self.ema.step(
951
+ self.get_model(),
952
+ self.get_num_updates(),
953
+ )
954
+ metrics.log_scalar(
955
+ "ema_decay",
956
+ self.ema.get_decay(),
957
+ priority=10000,
958
+ round=5,
959
+ weight=0,
960
+ )
961
+
962
+ if self.tpu:
963
+ import torch_xla.core.xla_model as xm
964
+
965
+ # mark step on TPUs
966
+ self._xla_markstep_and_send_to_cpu()
967
+
968
+ # only log stats every log_interval steps
969
+ # this causes wps to be misreported when log_interval > 1
970
+ logging_output = {}
971
+ if self.get_num_updates() % self.cfg.common.log_interval == 0:
972
+ # log memory usage
973
+ mem_info = xm.get_memory_info(self.device)
974
+ gb_free = mem_info["kb_free"] / 1024 / 1024
975
+ gb_total = mem_info["kb_total"] / 1024 / 1024
976
+ metrics.log_scalar(
977
+ "gb_free", gb_free, priority=1500, round=1, weight=0
978
+ )
979
+ metrics.log_scalar(
980
+ "gb_total", gb_total, priority=1600, round=1, weight=0
981
+ )
982
+ logging_outputs = self._xla_markstep_and_send_to_cpu(
983
+ logging_outputs
984
+ )
985
+ logging_output = self._reduce_and_log_stats(
986
+ logging_outputs, sample_size, grad_norm
987
+ )
988
+
989
+ # log whenever there's an XLA compilation, since these
990
+ # slow down training and may indicate opportunities for
991
+ # optimization
992
+ self._check_xla_compilation()
993
+ else:
994
+ if self.cuda and self.cuda_env is not None:
995
+ # log minimum free memory over the iteration
996
+ gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
997
+ torch.cuda.reset_peak_memory_stats()
998
+ gb_free = self.cuda_env.total_memory_in_GB - gb_used
999
+ metrics.log_scalar(
1000
+ "gb_free", gb_free, priority=1500, round=1, weight=0
1001
+ )
1002
+
1003
+ # log stats
1004
+ logging_output = self._reduce_and_log_stats(
1005
+ logging_outputs, sample_size, grad_norm
1006
+ )
1007
+
1008
+ # clear CUDA cache to reduce memory fragmentation
1009
+ if (
1010
+ self.cuda
1011
+ and self.cfg.common.empty_cache_freq > 0
1012
+ and (
1013
+ (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
1014
+ % self.cfg.common.empty_cache_freq
1015
+ )
1016
+ == 0
1017
+ ):
1018
+ torch.cuda.empty_cache()
1019
+
1020
+ if self.cfg.common.fp16 or self.cfg.common.amp:
1021
+ metrics.log_scalar(
1022
+ "loss_scale",
1023
+ (
1024
+ self.optimizer.scaler.loss_scale
1025
+ if self.cfg.common.fp16
1026
+ else self.optimizer.scaler.get_scale()
1027
+ ),
1028
+ priority=700,
1029
+ round=4,
1030
+ weight=0,
1031
+ )
1032
+
1033
+ metrics.log_stop_time("train_wall")
1034
+ return logging_output
1035
+
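# --- Editorial sketch (not part of the committed file) ----------------------
# Standalone restatement of the maybe_no_sync() logic inside train_step():
# when accumulating gradients over several mini-batches, the DDP all-reduce
# is deferred to the last backward pass; FSDP is excluded because no_sync()
# would hold full-size gradients on every rank.

import contextlib

def maybe_no_sync(model, batch_index, num_batches, world_size, is_fsdp=False):
    if (
        world_size > 1
        and hasattr(model, "no_sync")
        and batch_index < num_batches - 1
        and not is_fsdp
    ):
        return model.no_sync()          # skip the gradient all-reduce this pass
    return contextlib.ExitStack()       # no-op context manager

# Hypothetical use with a DistributedDataParallel-wrapped `ddp_model` and a
# list `samples` of accumulated mini-batches:
#
#   for i, sample in enumerate(samples):
#       with maybe_no_sync(ddp_model, i, len(samples), world_size):
#           loss = compute_loss(ddp_model, sample)
#           loss.backward()
# ----------------------------------------------------------------------------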
1036
+ @metrics.aggregate("valid")
1037
+ def valid_step(self, sample, raise_oom=False):
1038
+ """Do forward pass in evaluation mode."""
1039
+ if self.tpu:
1040
+ import torch_xla.core.xla_model as xm
1041
+
1042
+ xm.rendezvous("valid_step") # wait for all workers
1043
+
1044
+ # If EMA is enabled through store_ema=True
1045
+ # and task.uses_ema is True, pass the EMA model as a keyword
1046
+ # argument to the task.
1047
+ extra_kwargs = {}
1048
+ if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
1049
+ extra_kwargs["ema_model"] = self.ema.get_model()
1050
+
1051
+ with torch.no_grad():
1052
+ self.model.eval()
1053
+ self.criterion.eval()
1054
+
1055
+ sample, is_dummy_batch = self._prepare_sample(sample)
1056
+
1057
+ try:
1058
+ _loss, sample_size, logging_output = self.task.valid_step(
1059
+ sample, self.model, self.criterion, **extra_kwargs
1060
+ )
1061
+ except RuntimeError as e:
1062
+ if "out of memory" in str(e):
1063
+ self._log_oom(e)
1064
+ if not raise_oom:
1065
+ logger.warning(
1066
+ "ran out of memory in validation step, retrying batch"
1067
+ )
1068
+ for p in self.model.parameters():
1069
+ if p.grad is not None:
1070
+ p.grad = None # free some memory
1071
+ if self.cuda:
1072
+ torch.cuda.empty_cache()
1073
+ return self.valid_step(sample, raise_oom=True)
1074
+ raise e
1075
+
1076
+ logging_outputs = [logging_output]
1077
+ if is_dummy_batch:
1078
+ if torch.is_tensor(sample_size):
1079
+ sample_size.zero_()
1080
+ else:
1081
+ sample_size *= 0.0
1082
+
1083
+ # gather logging outputs from all replicas
1084
+ if self.data_parallel_world_size > 1:
1085
+ logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
1086
+ logging_outputs,
1087
+ sample_size,
1088
+ ignore=is_dummy_batch,
1089
+ )
1090
+
1091
+ # log validation stats
1092
+ if self.tpu:
1093
+ logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
1094
+ logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
1095
+
1096
+ return logging_output
1097
+
1098
+ def zero_grad(self):
1099
+ self.optimizer.zero_grad()
1100
+
1101
+ def lr_step_begin_epoch(self, epoch):
1102
+ """Adjust the learning rate at the beginning of the epoch."""
1103
+ self.lr_scheduler.step_begin_epoch(epoch)
1104
+ # prefer updating the LR based on the number of steps
1105
+ return self.lr_step_update()
1106
+
1107
+ def lr_reinit(self, total_updates, num_updates):
1108
+ self.lr_scheduler.reinit(total_updates, num_updates)
1109
+
1110
+ def lr_step(self, epoch, val_loss=None):
1111
+ """Adjust the learning rate at the end of the epoch."""
1112
+ self.lr_scheduler.step(epoch, val_loss)
1113
+ # prefer updating the LR based on the number of steps
1114
+ return self.lr_step_update()
1115
+
1116
+ def lr_step_update(self):
1117
+ """Update the learning rate after each update."""
1118
+ new_lr = self.lr_scheduler.step_update(self.get_num_updates())
1119
+ if isinstance(new_lr, dict):
1120
+ for k, v in new_lr.items():
1121
+ metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
1122
+ new_lr = new_lr.get("default", next(iter(new_lr.values())))
1123
+ else:
1124
+ metrics.log_scalar("lr", new_lr, weight=0, priority=300)
1125
+ return new_lr
1126
+
1127
+ def get_lr(self):
1128
+ """Get the current learning rate."""
1129
+ return self.optimizer.get_lr()
1130
+
1131
+ def get_model(self):
1132
+ """Get the (non-wrapped) model instance."""
1133
+ return self._model
1134
+
1135
+ def get_criterion(self):
1136
+ """Get the (non-wrapped) criterion instance."""
1137
+ return self._criterion
1138
+
1139
+ def get_meter(self, name):
1140
+ """[deprecated] Get a specific meter by name."""
1141
+ from fairseq import meters
1142
+
1143
+ if "get_meter" not in self._warn_once:
1144
+ self._warn_once.add("get_meter")
1145
+ utils.deprecation_warning(
1146
+ "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
1147
+ )
1148
+
1149
+ train_meters = metrics.get_meters("train")
1150
+ if train_meters is None:
1151
+ train_meters = {}
1152
+
1153
+ if name == "train_loss" and "loss" in train_meters:
1154
+ return train_meters["loss"]
1155
+ elif name == "train_nll_loss":
1156
+ # support for legacy train.py, which assumed this meter is
1157
+ # always initialized
1158
+ m = train_meters.get("nll_loss", None)
1159
+ return m or meters.AverageMeter()
1160
+ elif name == "wall":
1161
+ # support for legacy train.py, which assumed this meter is
1162
+ # always initialized
1163
+ m = metrics.get_meter("default", "wall")
1164
+ return m or meters.TimeMeter()
1165
+ elif name == "wps":
1166
+ m = metrics.get_meter("train", "wps")
1167
+ return m or meters.TimeMeter()
1168
+ elif name in {"valid_loss", "valid_nll_loss"}:
1169
+ # support for legacy train.py, which assumed these meters
1170
+ # are always initialized
1171
+ k = name[len("valid_") :]
1172
+ m = metrics.get_meter("valid", k)
1173
+ return m or meters.AverageMeter()
1174
+ elif name == "oom":
1175
+ return meters.AverageMeter()
1176
+ elif name in train_meters:
1177
+ return train_meters[name]
1178
+ return None
1179
+
1180
+ def get_num_updates(self):
1181
+ """Get the number of parameters updates."""
1182
+ return self._num_updates
1183
+
1184
+ def set_num_updates(self, num_updates):
1185
+ """Set the number of parameters updates."""
1186
+ self._num_updates = num_updates
1187
+ self.lr_step_update()
1188
+ if self.quantizer:
1189
+ self.quantizer.step_update(self._num_updates)
1190
+ metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
1191
+
1192
+ def clip_grad_norm(self, clip_norm):
1193
+ def agg_norm_fn(total_norm):
1194
+ total_norm = total_norm.cuda().float() ** 2
1195
+ total_norm = distributed_utils.all_reduce(
1196
+ total_norm, group=self.data_parallel_process_group
1197
+ )
1198
+ return total_norm ** 0.5
1199
+
1200
+ should_agg_norm = (
1201
+ self.is_fsdp
1202
+ and (
1203
+ self.data_parallel_process_group is not None
1204
+ or torch.distributed.is_initialized()
1205
+ )
1206
+ )
1207
+ return self.optimizer.clip_grad_norm(
1208
+ clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None
1209
+ )
1210
+
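# --- Editorial sketch (not part of the committed file) ----------------------
# The agg_norm_fn used above combines per-shard gradient norms under FSDP:
# each rank only sees the norm of its own parameter shard, and the global
# norm is sqrt(sum of squared local norms). Simulated here without a process
# group, with made-up shard norms.

import torch

local_shard_norms = [torch.tensor(3.0), torch.tensor(4.0)]   # one per rank
global_norm = sum(n ** 2 for n in local_shard_norms) ** 0.5
print(global_norm)   # tensor(5.) == sqrt(3^2 + 4^2)
# ----------------------------------------------------------------------------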
1211
+ def cumulative_training_time(self):
1212
+ if self._cumulative_training_time is None:
1213
+ # single GPU
1214
+ return self._local_cumulative_training_time()
1215
+ else:
1216
+ return self._cumulative_training_time
1217
+
1218
+ def _local_cumulative_training_time(self):
1219
+ """Aggregate training time in seconds."""
1220
+ return time.time() - self._start_time + self._previous_training_time
1221
+
1222
+ def _fp_convert_sample(self, sample):
1223
+ def apply_half(t):
1224
+ if t.dtype is torch.float32:
1225
+ return t.to(dtype=torch.half)
1226
+ return t
1227
+
1228
+ def apply_bfloat16(t):
1229
+ if t.dtype is torch.float32:
1230
+ return t.to(dtype=torch.bfloat16)
1231
+ return t
1232
+
1233
+ if self.cfg.common.fp16:
1234
+ sample = utils.apply_to_sample(apply_half, sample)
1235
+
1236
+ if self.cfg.common.bf16:
1237
+ sample = utils.apply_to_sample(apply_bfloat16, sample)
1238
+
1239
+ return sample
1240
+
1241
+ def _prepare_sample(self, sample, is_dummy=False):
1242
+ if sample == "DUMMY":
1243
+ raise Exception(
1244
+ "Trying to use an uninitialized 'dummy' batch. This usually indicates "
1245
+ "that the total number of batches is smaller than the number of "
1246
+ "participating GPUs. Try reducing the batch size or using fewer GPUs."
1247
+ )
1248
+
1249
+ if sample is None or len(sample) == 0:
1250
+ assert (
1251
+ self._dummy_batch is not None and len(self._dummy_batch) > 0
1252
+ ), "Invalid dummy batch: {}".format(self._dummy_batch)
1253
+ sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
1254
+ return sample, True
1255
+
1256
+ # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth
1257
+ # it makes sense to do the format conversion on the CPU and then transfer
1258
+ # a smaller buffer to the device. This also saves GPU memory capacity.
1259
+
1260
+ if self.cfg.common.on_cpu_convert_precision:
1261
+ sample = self._fp_convert_sample(sample)
1262
+
1263
+ if self.cuda:
1264
+ if self.pipeline_model_parallel:
1265
+ if 'target' in sample:
1266
+ sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
1267
+ else:
1268
+ sample = utils.move_to_cuda(sample)
1269
+ elif self.tpu and is_dummy:
1270
+ # the dummy batch may not be on the appropriate device
1271
+ sample = utils.move_to_cuda(sample, device=self.device)
1272
+
1273
+ if not self.cfg.common.on_cpu_convert_precision:
1274
+ sample = self._fp_convert_sample(sample)
1275
+
1276
+ if self._dummy_batch == "DUMMY":
1277
+ self._dummy_batch = sample
1278
+
1279
+ return sample, False
1280
+
1281
+ def _set_seed(self):
1282
+ # Set seed based on args.seed and the update number so that we get
1283
+ # reproducible results when resuming from checkpoints
1284
+ seed = self.cfg.common.seed + self.get_num_updates()
1285
+ utils.set_torch_seed(seed)
1286
+
1287
+ def _sync_stats(self):
1288
+ # Return True if it's using multiple GPUs and DDP or multiple GPUs with
1289
+ # BMUF and it's a bmuf sync with warmup iterations completed before.
1290
+ if self.data_parallel_world_size == 1:
1291
+ return False
1292
+ elif self.cfg.optimization.use_bmuf:
1293
+ return (
1294
+ self.get_num_updates() + 1
1295
+ ) % self.cfg.bmuf.global_sync_iter == 0 and (
1296
+ self.get_num_updates() + 1
1297
+ ) > self.cfg.bmuf.warmup_iterations
1298
+ else:
1299
+ return True
1300
+
1301
+ def _log_oom(self, exc):
1302
+ msg = "OOM: Ran out of memory with exception: {}".format(exc)
1303
+ logger.warning(msg)
1304
+ if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
1305
+ for device_idx in range(torch.cuda.device_count()):
1306
+ logger.warning(torch.cuda.memory_summary(device=device_idx))
1307
+ sys.stderr.flush()
1308
+
1309
+ def _aggregate_logging_outputs(
1310
+ self,
1311
+ logging_outputs: List[Dict[str, Any]],
1312
+ *extra_stats_to_sum,
1313
+ ignore=False,
1314
+ ):
1315
+ if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
1316
+ return self._fast_stat_sync_sum(
1317
+ logging_outputs, *extra_stats_to_sum, ignore=ignore
1318
+ )
1319
+ else:
1320
+ return self._all_gather_list_sync(
1321
+ logging_outputs, *extra_stats_to_sum, ignore=ignore
1322
+ )
1323
+
1324
+ def _all_gather_list_sync(
1325
+ self,
1326
+ logging_outputs: List[Dict[str, Any]],
1327
+ *extra_stats_to_sum,
1328
+ ignore=False,
1329
+ ):
1330
+ """
1331
+ Sync logging outputs across workers. all_gather_list_sync is
1332
+ suitable when logging outputs are complex types.
1333
+ """
1334
+ if self.tpu:
1335
+ raise NotImplementedError
1336
+ if ignore:
1337
+ logging_outputs = []
1338
+ results = list(
1339
+ zip(
1340
+ *distributed_utils.all_gather_list(
1341
+ [logging_outputs] + list(extra_stats_to_sum),
1342
+ max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
1343
+ group=self.data_parallel_process_group,
1344
+ )
1345
+ )
1346
+ )
1347
+ logging_outputs, extra_stats_to_sum = results[0], results[1:]
1348
+ logging_outputs = list(chain.from_iterable(logging_outputs))
1349
+ extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
1350
+ return logging_outputs, extra_stats_to_sum
1351
+
1352
+ def _fast_stat_sync_sum(
1353
+ self,
1354
+ logging_outputs: List[Dict[str, Any]],
1355
+ *extra_stats_to_sum,
1356
+ ignore=False,
1357
+ ):
1358
+ """
1359
+ Sync logging outputs across workers. fast_stat_sync_sum is
1360
+ faster than all_gather_list_sync, but is only suitable when
1361
+ logging outputs are scalars and can be summed. Note that
1362
+ *logging_outputs* cannot contain any nested dicts/lists.
1363
+ """
1364
+ data = {}
1365
+ for i, stat in enumerate(extra_stats_to_sum):
1366
+ data["extra_stats_" + str(i)] = stat
1367
+ if len(logging_outputs) > 0:
1368
+ log_keys = list(logging_outputs[0].keys())
1369
+ for k in log_keys:
1370
+ if not ignore:
1371
+ v = sum(log[k] for log in logging_outputs if k in log)
1372
+ else:
1373
+ v = logging_outputs[0][k]
1374
+ v = torch.zeros_like(v) if torch.is_tensor(v) else 0
1375
+ data["logging_outputs_" + k] = v
1376
+ else:
1377
+ log_keys = None
1378
+
1379
+ data = distributed_utils.all_reduce_dict(
1380
+ data, device=self.device, group=self.data_parallel_process_group
1381
+ )
1382
+
1383
+ extra_stats_to_sum = [
1384
+ data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
1385
+ ]
1386
+ if log_keys is not None:
1387
+ logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
1388
+ else:
1389
+ logging_outputs = []
1390
+ return logging_outputs, extra_stats_to_sum
1391
+
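# --- Editorial sketch (not part of the committed file) ----------------------
# Locally, _fast_stat_sync_sum() first reduces the per-mini-batch logging
# outputs to a single dict of sums before the cross-worker all_reduce_dict().
# The same reduction, shown on plain Python numbers:

logging_outputs = [
    {"loss": 2.0, "ntokens": 100, "nsentences": 4},
    {"loss": 1.5, "ntokens": 120, "nsentences": 4},
]
summed = {
    k: sum(log[k] for log in logging_outputs if k in log)
    for k in logging_outputs[0]
}
print(summed)   # {'loss': 3.5, 'ntokens': 220, 'nsentences': 8}
# Across workers, all_reduce_dict() then sums these per-key values again.
# ----------------------------------------------------------------------------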
1392
+ def _check_grad_norms(self, grad_norm):
1393
+ """Check that grad norms are consistent across workers."""
1394
+ if self._grad_norm_buf is not None:
1395
+ self._grad_norm_buf.zero_()
1396
+ self._grad_norm_buf[self.data_parallel_rank] = grad_norm
1397
+ distributed_utils.all_reduce(
1398
+ self._grad_norm_buf, group=self.data_parallel_process_group
1399
+ )
1400
+
1401
+ def is_consistent(tensor):
1402
+ max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
1403
+ return (
1404
+ (torch.isfinite(tensor).all()
1405
+ and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
1406
+ or
1407
+ (self.cfg.common.amp and not torch.isfinite(tensor).all())
1408
+ # in case of amp non-finite grads are fine
1409
+ )
1410
+
1411
+ if not is_consistent(self._grad_norm_buf):
1412
+ pretty_detail = "\n".join(
1413
+ "rank {:3d} = {:.8f}".format(r, n)
1414
+ for r, n in enumerate(self._grad_norm_buf.tolist())
1415
+ )
1416
+ error_detail = "grad_norm across the workers:\n{}\n".format(
1417
+ pretty_detail
1418
+ )
1419
+ # use FloatingPointError to trigger NanDetector
1420
+ raise FloatingPointError(
1421
+ "Fatal error: gradients are inconsistent between workers. "
1422
+ "Try --ddp-backend=legacy_ddp. "
1423
+ "Or are you mixing up different generation of GPUs in training?"
1424
+ + "\n"
1425
+ + "-" * 80
1426
+ + "\n{}\n".format(error_detail)
1427
+ + "-" * 80
1428
+ )
1429
+
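# --- Editorial sketch (not part of the committed file) ----------------------
# The is_consistent() test above flags workers whose gradient norms drift by
# more than a relative tolerance of 1e-6 (the AMP escape hatch for non-finite
# norms is omitted here). On hypothetical per-rank norms:

import torch

def is_consistent(t, eps=1e-6):
    max_abs_diff = torch.max(torch.abs(t - t[0]))
    return bool(torch.isfinite(t).all() and (max_abs_diff / (t[0] + eps) < eps).all())

print(is_consistent(torch.tensor([0.5, 0.5, 0.5])))   # True
print(is_consistent(torch.tensor([0.5, 0.5, 0.7])))   # False -> FloatingPointError above
# ----------------------------------------------------------------------------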
1430
+ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
1431
+ if grad_norm is not None and (
1432
+ not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
1433
+ ):
1434
+ metrics.log_speed("ups", 1.0, priority=100, round=2)
1435
+ metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
1436
+ if self.cfg.optimization.clip_norm > 0:
1437
+ metrics.log_scalar(
1438
+ "clip",
1439
+ torch.where(
1440
+ grad_norm > self.cfg.optimization.clip_norm,
1441
+ grad_norm.new_tensor(100),
1442
+ grad_norm.new_tensor(0),
1443
+ ),
1444
+ priority=500,
1445
+ round=1,
1446
+ )
1447
+
1448
+ with metrics.aggregate() as agg:
1449
+ if logging_outputs is not None:
1450
+ self.task.reduce_metrics(logging_outputs, self.get_criterion())
1451
+ del logging_outputs
1452
+
1453
+ # extra warning for criterions that don't properly log a loss value
1454
+ if "loss" not in agg:
1455
+ if "loss" not in self._warn_once:
1456
+ self._warn_once.add("loss")
1457
+ logger.warning(
1458
+ "Criterion.reduce_metrics did not log a 'loss' value, "
1459
+ "which may break some functionality"
1460
+ )
1461
+ metrics.log_scalar("loss", -1)
1462
+
1463
+ # support legacy interface
1464
+ if self.tpu:
1465
+ logging_output = {}
1466
+ else:
1467
+ logging_output = agg.get_smoothed_values()
1468
+ logging_output["sample_size"] = sample_size
1469
+ for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
1470
+ if key_to_delete in logging_output:
1471
+ del logging_output[key_to_delete]
1472
+ return logging_output
1473
+
1474
+ def _check_xla_compilation(self):
1475
+ import torch_xla.debug.metrics as met
1476
+
1477
+ compile_stats = met.metric_data("CompileTime")
1478
+ if compile_stats is None:
1479
+ return
1480
+ num_xla_compiles = compile_stats[0]
1481
+ if num_xla_compiles > self._num_xla_compiles:
1482
+ logger.warning(
1483
+ "XLA compilation detected on device #{}; too many of these can lead "
1484
+ "to slow training, but we expect a few in the beginning".format(
1485
+ self.cfg.distributed_training.distributed_rank
1486
+ )
1487
+ )
1488
+ self._num_xla_compiles = num_xla_compiles
1489
+
1490
+ def _xla_markstep_and_send_to_cpu(self, data=None):
1491
+ import torch_xla.core.xla_model as xm
1492
+
1493
+ xm.mark_step()
1494
+ if data is not None:
1495
+ from fairseq.utils import xla_device_to_cpu
1496
+
1497
+ return xla_device_to_cpu(data)
1498
+
1499
+
1500
+ def _catalog_shared_params(module, memo=None, prefix=""):
1501
+ if memo is None:
1502
+ first_call = True
1503
+ memo = {}
1504
+ else:
1505
+ first_call = False
1506
+ for name, param in module._parameters.items():
1507
+ param_prefix = prefix + ("." if prefix else "") + name
1508
+ if param not in memo:
1509
+ memo[param] = []
1510
+ memo[param].append(param_prefix)
1511
+ for name, m in module._modules.items():
1512
+ if m is None:
1513
+ continue
1514
+ submodule_prefix = prefix + ("." if prefix else "") + name
1515
+ _catalog_shared_params(m, memo, submodule_prefix)
1516
+ if first_call:
1517
+ return [x for x in memo.values() if len(x) > 1]
1518
+
1519
+
1520
+ def _get_module_by_path(module, path):
1521
+ path = path.split(".")
1522
+ for name in path:
1523
+ module = getattr(module, name)
1524
+ return module
1525
+
1526
+
1527
+ def _set_module_by_path(module, path, value):
1528
+ path = path.split(".")
1529
+ for name in path[:-1]:
1530
+ module = getattr(module, name)
1531
+ setattr(module, path[-1], value)
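# --- Editorial sketch (not part of the committed file) ----------------------
# Demonstrates the three module-path helpers above on a hypothetical toy model
# with tied weights, the situation the Trainer constructor re-ties after the
# model has been moved to the target device/dtype.

import torch.nn as nn

class TinyTiedModel(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab, bias=False)
        self.proj.weight = self.embed.weight   # weight tying

model = TinyTiedModel()
shared = _catalog_shared_params(model)
print(shared)   # [['embed.weight', 'proj.weight']]

# Re-tie every group to a single reference, exactly as __init__ does after
# the device/dtype transfer:
for group in shared:
    ref = _get_module_by_path(model, group[0])
    for path in group[1:]:
        _set_module_by_path(model, path, ref)
# ----------------------------------------------------------------------------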