| """ |
| Distributed training utilities |
| """ |
|
|
| import os |
| import torch |
| import torch.distributed as dist |
|
|
|
|
def setup_distributed():
    """Initialize distributed training from torchrun-style env vars.

    Reads ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` from the environment,
    falling back to single-process defaults when ``RANK`` is unset. For
    multi-process runs, pins the CUDA device and initializes the NCCL
    process group.

    Returns:
        tuple[int, int, int]: (rank, world_size, local_rank)
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        # Default the companion vars so a partially populated environment
        # (e.g. only RANK exported by hand) does not raise KeyError.
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    else:
        rank = 0
        world_size = 1
        local_rank = 0

    if world_size > 1:
        # Pin this process to its GPU *before* creating the process group
        # so NCCL binds each rank to its own device instead of cuda:0.
        torch.cuda.set_device(local_rank)
        dist.init_process_group('nccl')

    return rank, world_size, local_rank
|
|
|
|
def cleanup_distributed():
    """Tear down the process group if one is active.

    Safe to call unconditionally: a no-op in single-process runs and on
    torch builds compiled without distributed support.
    """
    # is_available() guards builds without distributed support, where
    # the process-group APIs are not usable at all.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def print_rank0(msg, rank=0):
    """Emit *msg* on rank 0 only; a no-op on every other rank."""
    if rank != 0:
        return
    print(msg)
|
|
|
|
def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a loop over the batch dimension.
    Avoids CUBLAS strided batched routines which have issues on L40S/CUDA 12.8/PyTorch 2.10.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n)
    """
    batch = a.shape[0]
    # torch.stack raises on an empty list, so handle batch == 0 explicitly;
    # new_empty preserves the input's dtype and device.
    if batch == 0:
        return a.new_empty(batch, a.shape[1], b.shape[2])
    # Intentionally per-item torch.mm (NOT torch.bmm) — see note above.
    return torch.stack([torch.mm(a[i], b[i]) for i in range(batch)], dim=0)
|
|