import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Initialize the default NCCL process group and bind this process to its GPU.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
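

# Usage sketch (illustrative, not part of the original file): under a
# torchrun-style launcher each process already has RANK / LOCAL_RANK /
# WORLD_SIZE / MASTER_ADDR / MASTER_PORT in its environment, so setup can be
# driven entirely from there:
#
#   setup_dist(
#       rank=int(os.environ['RANK']),
#       local_rank=int(os.environ['LOCAL_RANK']),
#       world_size=int(os.environ['WORLD_SIZE']),
#       master_addr=os.environ['MASTER_ADDR'],
#       master_port=os.environ['MASTER_PORT'],
#   )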


def read_file_dist(path):
    """
    Read a binary file in a distributed setting.

    The file is read from disk only by the rank 0 process and broadcast to all
    other processes, which avoids redundant reads of shared storage.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            # Rank 0 reads the file and stages it as a byte tensor on the GPU.
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # Broadcast the payload size so other ranks can allocate a receive buffer.
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # Broadcast the file contents from rank 0 to all ranks.
        dist.broadcast(data, src=0)
        # Convert back to an in-memory binary stream.
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data
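

# Usage sketch (illustrative): load a checkpoint on every rank while touching
# the filesystem only from rank 0. torch.load accepts the io.BytesIO returned
# here; 'checkpoint.pt' is a placeholder path.
#
#   buffer = read_file_dist('checkpoint.pt')
#   state_dict = torch.load(buffer, map_location='cpu')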


def unwrap_dist(model):
    """
    Unwrap the model from distributed training.
    """
    if isinstance(model, DDP):
        return model.module
    return model
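

# Usage sketch (illustrative): save weights without the 'module.' prefix that
# DDP adds to parameter names; 'model.pt' is a placeholder path.
#
#   torch.save(unwrap_dist(model).state_dict(), 'model.pt')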


@contextmanager
def master_first():
    """
    A context manager that ensures the master (rank 0) process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            # Master runs its block, then releases the others at the barrier.
            yield
            dist.barrier()
        else:
            # Non-master ranks wait until the master has finished.
            dist.barrier()
            yield
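

# Usage sketch (illustrative): let rank 0 populate a shared cache before the
# remaining ranks read from it; download_to_cache is a hypothetical helper.
#
#   with master_first():
#       download_to_cache('https://example.com/data.tar')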


@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master (first rank on each node)
    executes first.

    Assumes ranks are assigned contiguously per node, with one process per GPU
    on every node.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            # Local masters run their block, then release peers at the barrier.
            yield
            dist.barrier()
        else:
            # Other ranks wait until their node's local master has finished.
            dist.barrier()
            yield
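

# Usage sketch (illustrative): on multi-node jobs, let one process per node
# stage data onto node-local disk before its peers use it;
# extract_to_local_disk is a hypothetical helper.
#
#   with local_master_first():
#       extract_to_local_disk('/shared/data.tar', '/local/scratch')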