# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import tempfile
import uuid
import warnings
from collections import OrderedDict
from datetime import date
from pathlib import Path
from typing import Any, Dict, Mapping

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_master(args):
    """Return True only on the process with distributed rank 0."""
    return args.distributed_rank == 0
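

# --- Illustrative usage sketch (not part of the original module) ---
# Typical pattern: restrict logging or checkpoint writes to the master process.
# Assumes `args.distributed_rank` has been populated by the launcher.
def _example_master_only_log(args, message):
    if is_master(args):
        logger.info(message)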


def init_distributed_mode(rank, args):
    """Initialise torch.distributed for one of three launch modes:

    * ``args.launcher == 'spawn'``: single node, ``rank`` comes from
      ``torch.multiprocessing.spawn``.
    * torchrun / ``torch.distributed.launch``: ``RANK`` and ``LOCAL_RANK``
      are read from the environment.
    * SLURM: ``SLURM_PROCID`` is read from the environment.
    """
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])
    if args.launcher == 'spawn':  # single node with multiprocessing.spawn
        args.world_size = args.num_gpus
        args.rank = rank
        args.gpu = rank
    elif 'RANK' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    if args.world_size == 1:
        # Nothing to initialise for single-process runs.
        return
    if 'MASTER_ADDR' in os.environ:
        args.dist_url = 'tcp://{}:{}'.format(os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
    print(f'gpu={args.gpu}, rank={args.rank}, world_size={args.world_size}')
    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
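

# --- Illustrative usage sketch (not part of the original module) ---
# One way to drive init_distributed_mode for single-node multi-GPU training via
# torch.multiprocessing.spawn. The argparse fields used here (`launcher`,
# `num_gpus`) mirror what the function above expects; everything else is an
# assumption made up for the example.
def _example_spawn_worker(rank, args):
    # `rank` is supplied by mp.spawn; with args.launcher == 'spawn' the function
    # above copies it into args.rank / args.gpu and initialises the process group.
    init_distributed_mode(rank, args)
    # ... build the model, wrap it in DistributedDataParallel, run training ...


def _example_spawn_launch(args):
    import torch.multiprocessing as mp
    # args.launcher == 'spawn' and args.num_gpus must be set for this path.
    mp.spawn(_example_spawn_worker, args=(args,), nprocs=args.num_gpus)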


def gather_list_and_concat(tensor):
    """All-gather `tensor` from every worker and concatenate along dim 0."""
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)
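

# --- Illustrative usage sketch (not part of the original module) ---
# Collects per-rank prediction tensors onto every rank. all_gather requires the
# tensor to have the same shape on every worker; with the NCCL backend it must
# also live on the current GPU.
def _example_gather_predictions(preds):
    return gather_list_and_concat(preds.cuda())  # shape: (world_size * n, ...)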


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def get_default_group():
    return dist.group.WORLD


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    rank = get_rank()
    world_size = get_world_size()

    buffer_size = max_size * world_size
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    if torch.is_tensor(data):
        # Move tensors to CPU before pickling so the payload is device-agnostic.
        data = data.cpu()
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size))

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start:start + size].copy_(cpu_buffer[:size])

    dist.all_reduce(buffer, group=group)
    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size:(i + 1) * max_size]
            enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data. '
            'Try rerunning with --ddp-backend=no_c10d and see if that helps.'
        )
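

# --- Illustrative usage sketch (not part of the original module) ---
# Gathers a small picklable metrics dict from every worker so the master can
# report a global summary. The pickled payload (plus the 4-byte header) must
# stay under `max_size` bytes.
def _example_gather_metrics(local_metrics):
    per_rank = all_gather_list(local_metrics)  # list with one entry per rank
    if get_rank() == 0:
        logger.info('per-rank metrics: %s', per_rank)
    return per_rank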


def all_reduce_dict(
    data: Mapping[str, Any],
    device,
    group=None,
) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group (optional): group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.stack(list(data.values())).to(device=device)
        dist.all_reduce(buf, group=group)
        return {k: buf[i] for i, k in enumerate(data)}

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
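

# --- Illustrative usage sketch (not part of the original module) ---
# Sums simple logging statistics (Python scalars or tensors) across workers so
# every process sees the global totals. The keys and values here are assumptions
# made up for the example; `device` is a torch.device such as torch.device('cuda').
def _example_reduce_stats(loss, n_correct, device):
    stats = {'loss': loss.detach(), 'n_correct': n_correct}
    return all_reduce_dict(stats, device=device)  # all workers must call this together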


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    if Path("/checkpoint/").is_dir():
        p = Path(f"/checkpoint/{user}/experiments")
    else:
        p = Path("/tmp/experiments")
    p.mkdir(parents=True, exist_ok=True)
    return p


def get_init_file():
    # Init file must not exist, but its parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = Path(str(get_shared_folder()) + f"/{uuid.uuid4().hex}_init")
    if init_file.exists():
        os.remove(str(init_file))
    return init_file
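

# --- Illustrative usage sketch (not part of the original module) ---
# Turns the shared init file into a file:// URL for file-based rendezvous with
# init_process_group, e.g. when MASTER_ADDR is not set. `args.dist_url` is the
# field consumed by init_distributed_mode above.
def _example_file_init_url(args):
    args.dist_url = get_init_file().as_uri()  # e.g. file:///checkpoint/<user>/experiments/<hex>_init
    return args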