"""
Distributed Training Utilities
This file contains utility functions for distributed training with PyTorch.
It provides tools for setting up distributed environments, efficient file handling
across processes, model unwrapping, and synchronization mechanisms to coordinate
execution across multiple GPUs and nodes.
"""

import os
import io
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Set up the distributed training environment.

    Args:
        rank (int): Global rank of the current process.
        local_rank (int): Local rank of the current process on this node.
        world_size (int): Total number of processes in the distributed training job.
        master_addr (str): IP address of the master node.
        master_port (str): Port on the master node used for communication.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    # Set the CUDA device for the current process
    torch.cuda.set_device(local_rank)
    # Initialize the process group for distributed communication
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
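
# Example (illustrative sketch): when the job is launched with `torchrun`, the
# launcher exposes these values as environment variables, so a typical call site
# looks like the following (the exact variable names depend on the launcher):
#
#   setup_dist(
#       rank=int(os.environ['RANK']),
#       local_rank=int(os.environ['LOCAL_RANK']),
#       world_size=int(os.environ['WORLD_SIZE']),
#       master_addr=os.environ['MASTER_ADDR'],
#       master_port=os.environ['MASTER_PORT'],
#   )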


def read_file_dist(path):
    """
    Read a binary file in a distributed setting.

    The file is read only once, by the rank 0 process, and then broadcast to all
    other processes. This reduces I/O overhead in distributed training.

    Args:
        path (str): Path to the file to be read.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Prepare a tensor to hold the file size
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            # Master process reads the file
            with open(path, 'rb') as f:
                data = f.read()
            # Convert the binary data to a CUDA tensor for broadcasting
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # Broadcast the file size to all processes
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            # Non-master processes allocate a buffer to receive the data
            data = torch.ByteTensor(size[0].item()).cuda()
        # Broadcast the actual file data to all processes
        dist.broadcast(data, src=0)
        # Convert the tensor back to binary data
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        # Non-distributed or single-process case: just read the file directly
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data
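
# Example (illustrative sketch; `ckpt_path` is a hypothetical placeholder): load a
# checkpoint on every rank while only rank 0 touches the filesystem. torch.load
# accepts the returned io.BytesIO object directly.
#
#   state_dict = torch.load(read_file_dist(ckpt_path), map_location='cpu')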


def unwrap_dist(model):
    """
    Unwrap a model from its distributed training wrapper.

    Args:
        model: A potentially wrapped PyTorch model.

    Returns:
        The underlying model without the DistributedDataParallel wrapper.
    """
    if isinstance(model, DDP):
        return model.module
    return model
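
# Example (illustrative; the filename is a placeholder): save weights without the
# DDP `module.` prefix so the checkpoint can later be loaded by an unwrapped model.
#
#   torch.save(unwrap_dist(model).state_dict(), 'checkpoint.pt')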


@contextmanager
def master_first():
    """
    A context manager that ensures the master process (rank 0) executes first.
    All other processes wait for the master to finish before proceeding.

    Usage:
        with master_first():
            # Code that should execute in the master first, then others
    """
    if not dist.is_initialized():
        # If not in distributed mode, just execute normally
        yield
    else:
        if dist.get_rank() == 0:
            # Master process executes the code
            yield
            # Signal completion to other processes
            dist.barrier()
        else:
            # Other processes wait for master to finish
            dist.barrier()
            # Then execute the code
            yield
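
# Example (illustrative sketch; assumes torchvision is available and the root path
# is a placeholder): let rank 0 download the dataset first, then the remaining
# ranks construct it from the already-cached files.
#
#   with master_first():
#       dataset = torchvision.datasets.CIFAR10('/shared/data', download=True)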


@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process (the first process on
    each node) executes first. Other processes on the same node wait before proceeding.

    Usage:
        with local_master_first():
            # Code that should execute in the local master first, then others
    """
    if not dist.is_initialized():
        # If not in distributed mode, just execute normally
        yield
    else:
        # The local master is identified by assuming one process per visible GPU
        # and the same GPU count on every node, so rank % device_count == 0 marks
        # the first process on each node.
        if dist.get_rank() % torch.cuda.device_count() == 0:
            # Local master process executes the code
            yield
            # Signal completion to other processes
            dist.barrier()
        else:
            # Other processes wait for the local master to finish
            dist.barrier()
            # Then execute the code
            yield
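
# Example (illustrative sketch; the paths are placeholders and `shutil` would need
# to be imported): unpack an archive to node-local storage once per node, so
# processes sharing a machine do not race on the same extraction.
#
#   with local_master_first():
#       if not os.path.exists('/tmp/dataset'):
#           shutil.unpack_archive('/shared/dataset.tar', '/tmp/dataset')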