# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
import tempfile
import uuid
from datetime import date
from pathlib import Path
from collections import OrderedDict
from typing import Any, Dict, Mapping

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_master(args):
    return args.distributed_rank == 0


def init_distributed_mode(rank, args):
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    if args.launcher == 'spawn':
        # single node with multiprocessing.spawn
        args.world_size = args.num_gpus
        args.rank = rank
        args.gpu = rank
    elif 'RANK' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    if args.world_size == 1:
        args.distributed = False
        return

    if 'MASTER_ADDR' in os.environ:
        args.dist_url = 'tcp://{}:{}'.format(os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])

    print(f'gpu={args.gpu}, rank={args.rank}, world_size={args.world_size}')

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()


def gather_list_and_concat(tensor):
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def get_default_group():
    return dist.group.WORLD


def all_reduce(tensor, group=None):
    # Thin wrapper around dist.all_reduce that falls back to the global
    # group; used by all_gather_list and all_reduce_dict below.
    if group is None:
        group = get_default_group()
    return dist.all_reduce(tensor, group=group)


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    rank = get_rank()
    world_size = get_world_size()

    buffer_size = max_size * world_size
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    # Move tensors to CPU before pickling; other picklable objects are
    # serialized as-is.
    if torch.is_tensor(data):
        data = data.cpu()
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size))

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start:start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size:(i + 1) * max_size]
            enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data. '
            'Try rerunning with --ddp-backend=no_c10d and see if that helps.'
        )


def all_reduce_dict(
    data: Mapping[str, Any],
    device,
    group=None,
) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group (optional): group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.stack(list(data.values())).to(device=device)
        all_reduce(buf, group=group)
        return {k: buf[i] for i, k in enumerate(data)}

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError(key)

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    if Path("/checkpoint/").is_dir():
        p = Path(f"/checkpoint/{user}/experiments")
        p.mkdir(exist_ok=True)
        return p
    else:
        p = Path("/tmp/experiments")
        p.mkdir(exist_ok=True)
        return p


def get_init_file():
    # Init file must not exist, but its parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = Path(str(get_shared_folder()) + f"/{uuid.uuid4().hex}_init")
    if init_file.exists():
        os.remove(str(init_file))
    return init_file
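

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way to
# wire the helpers above together for a single-node, multi-GPU run launched
# with torch.multiprocessing.spawn. The attributes set on `args` below
# (launcher, num_gpus, world_size, distributed, dist_url) are the ones
# init_distributed_mode reads; the rendezvous port and the demo payloads are
# assumptions made for this example only.
# ---------------------------------------------------------------------------
def _demo_worker(rank, args):
    init_distributed_mode(rank, args)
    if not args.distributed:
        return
    device = torch.device('cuda', args.gpu)
    # Sum a per-rank scalar across workers.
    stats = all_reduce_dict({'rank': float(rank)}, device=device)
    # Gather an arbitrary picklable object from every worker.
    gathered = all_gather_list({'rank': args.rank, 'host': socket.gethostname()})
    if args.rank == 0:
        print('sum of ranks: {}'.format(stats['rank'].item()), flush=True)
        print('gathered: {}'.format(gathered), flush=True)
    dist.destroy_process_group()


if __name__ == '__main__':
    import argparse
    import torch.multiprocessing as mp

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-gpus', type=int, default=torch.cuda.device_count())
    args = parser.parse_args()
    args.launcher = 'spawn'
    args.world_size = args.num_gpus
    args.distributed = False
    # Arbitrary local rendezvous address; MASTER_ADDR/MASTER_PORT in the
    # environment would override it inside init_distributed_mode.
    args.dist_url = 'tcp://localhost:{}'.format(random.randint(20000, 40000))
    mp.spawn(_demo_worker, args=(args,), nprocs=args.num_gpus)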