Spaces:
Runtime error
Runtime error
File size: 2,424 Bytes
4a4ed11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""
Helpers for distributed training.
"""
import io
import os
import socket
import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3
def setup_dist():
"""
Setup a distributed process group.
"""
if dist.is_initialized():
return
os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
comm = MPI.COMM_WORLD
backend = "gloo" if not th.cuda.is_available() else "nccl"
if backend == "gloo":
hostname = "localhost"
else:
hostname = socket.gethostbyname(socket.getfqdn())
os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
os.environ["RANK"] = str(comm.rank)
os.environ["WORLD_SIZE"] = str(comm.size)
port = comm.bcast(_find_free_port(), root=0)
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend=backend, init_method="env://")
def dev():
"""
Get the device to use for torch.distributed.
"""
if th.cuda.is_available():
return th.device(f"cuda")
return th.device("cpu")
def load_state_dict(path, **kwargs):
"""
Load a PyTorch file without redundant fetches across MPI ranks.
"""
chunk_size = 2 ** 30 # MPI has a relatively small size limit
if MPI.COMM_WORLD.Get_rank() == 0:
with bf.BlobFile(path, "rb") as f:
data = f.read()
num_chunks = len(data) // chunk_size
if len(data) % chunk_size:
num_chunks += 1
MPI.COMM_WORLD.bcast(num_chunks)
for i in range(0, len(data), chunk_size):
MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
else:
num_chunks = MPI.COMM_WORLD.bcast(None)
data = bytes()
for _ in range(num_chunks):
data += MPI.COMM_WORLD.bcast(None)
return th.load(io.BytesIO(data), **kwargs)
def sync_params(params):
"""
Synchronize a sequence of Tensors across ranks from rank 0.
"""
for p in params:
with th.no_grad():
dist.broadcast(p, 0)
def _find_free_port():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
finally:
s.close()
|