# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import logging
import os

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init(args):
    """init.

    Initialize DDP using the given rendezvous file.
    """
    global rank, world_size
    if args.ddp:
        assert args.rank is not None and args.world_size is not None
        rank = args.rank
        world_size = args.world_size
    if world_size == 1:
        return
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method='file://' + os.path.abspath(args.rendezvous_file),
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)


def average(metrics, count=1.):
    """average.

    Average all the relevant metrics across processes.
    `metrics` should be a 1D vector of float32 values.
    Returns the average of `metrics` over all hosts. You can use `count` to control
    the weight of each worker.
    """
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
    """wrap.

    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    if world_size > 1:
        torch.distributed.barrier()


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """loader.

    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.

    :param dataset: the dataset to be parallelized
    :param args: relevant args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant args
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler would otherwise replicate some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
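

# Minimal usage sketch of the helpers above: how `init`, `wrap`, `loader`,
# `average` and `barrier` are typically combined around an evaluation pass.
# The `args` attributes match what `init` reads; the model, dataset and metric
# below are hypothetical placeholders, and a CUDA device is assumed because
# `wrap` and `average` put tensors on the GPU. This is an illustration, not
# part of the module's API.
if __name__ == "__main__":
    import argparse

    from torch.utils.data import TensorDataset

    _args = argparse.Namespace(
        ddp=int(os.environ.get("WORLD_SIZE", 1)) > 1,
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
        ddp_backend="nccl",
        rendezvous_file="./rendezvous")  # hypothetical rendezvous file path
    init(_args)  # no-op when world_size == 1

    model = wrap(torch.nn.Linear(16, 1).cuda())  # DDP-wrapped only when distributed
    dataset = TensorDataset(torch.randn(128, 16))
    train_loader = loader(dataset, batch_size=8, shuffle=True)   # DistributedSampler path
    valid_loader = loader(dataset, batch_size=8, shuffle=False)  # manual per-rank shard

    total_loss = 0.
    for (batch,) in valid_loader:
        total_loss += model(batch.cuda()).pow(2).mean().item()

    # Aggregate the per-rank metric, weighting each worker by the number of
    # batches it processed, then synchronize before e.g. checkpointing on rank 0.
    avg_loss, = average([total_loss / len(valid_loader)], count=len(valid_loader))
    barrier()
    logger.info("average validation loss: %.4f", avg_loss)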