# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """ Train a network across multiple GPUs. """ from fairseq.dataclass.configs import FairseqConfig from fairseq.distributed import utils as distributed_utils from fairseq.trainer import Trainer try: from fairseq.model_parallel.megatron.mpu import ( get_data_parallel_rank, get_data_parallel_world_size, get_model_parallel_src_rank, get_cuda_rng_tracker, ) has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False class MegatronTrainer(Trainer): """Main class for model parallel with data parallel training.""" def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): if not has_megatron_submodule: raise ImportError( "\n\nPlease install the megatron submodule:" "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) super().__init__(cfg, task, model, criterion, **kwargs) def clip_grad_norm(self, clip_norm): def _aggregate_model_parallel_grad_norm(total_norm): total_norm = total_norm ** 2 distributed_utils.all_reduce( total_norm, group=distributed_utils.get_model_parallel_group() ) total_norm = total_norm ** 0.5 return total_norm return self.optimizer.clip_grad_norm( clip_norm, aggregate_norm_fn=_aggregate_model_parallel_grad_norm, ) def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" extra_state['rng_tracker_states'] \ = get_cuda_rng_tracker().get_states() super().save_checkpoint(filename, extra_state) def load_checkpoint( self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None, reset_meters=False, ): extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters) if extra_state is not None and 'rng_tracker_states' in extra_state: get_cuda_rng_tracker().set_states( extra_state['rng_tracker_states']) return extra_state