from __future__ import annotations

import collections
import collections.abc
import ctypes
import functools
import os
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Callable, Optional, TypeVar

import pynvml
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from cosmos_transfer1.utils import log
from cosmos_transfer1.utils.ddp_config import DDPConfig
from cosmos_transfer1.utils.device import Device

try:
    from megatron.core import parallel_state
except ImportError:
    print("Megatron-core is not installed.")

T = TypeVar("T")


def init() -> int | None:
    """Initialize distributed training."""
    # Pin this process to the CPU cores closest to its GPU (NUMA-aware affinity).
    pynvml.nvmlInit()
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = Device(local_rank)
    os.sched_setaffinity(0, device.get_cpu_affinity())

    # Configure NCCL error handling and initialize the process group.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
        torch.cuda.set_device(local_rank)
        # Read the NCCL timeout (in seconds) from the environment, defaulting to 30 minutes.
        timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
        timeout_timedelta = timedelta(seconds=int(timeout_seconds))
        dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
        log.critical(
            f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
            rank0_only=False,
        )
    # Increase the maximum L2 fetch granularity (cudaLimit 0x05, cudaLimitMaxL2FetchGranularity) to 128 bytes.
    _libcudart = ctypes.CDLL("libcudart.so")
    p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
    log.info(f"Training with {get_world_size()} GPUs.")


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get world size. How many GPUs are available in this job.

    Returns:
        world_size (int): The total number of GPUs available in this job.
    """
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size(group)
    return world_size


def is_rank0() -> bool:
    """Check if current process is the master GPU.

    Returns:
        (bool): True if this function is called from the master GPU, else False.
    """
    return get_rank() == 0


def rank0_only(func: Callable) -> Callable:
    """Apply this function only to the master GPU.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): a function.

    Returns:
        (Callable): A function wrapper executing the function only on the master GPU.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank0():
            return func(*args, **kwargs)
        else:
            return None

    return wrapper


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on the other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank0():
            result = func(*args, **kwargs)
        barrier()
        if not is_rank0():
            result = func(*args, **kwargs)
        return result

    return wrapper
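

# Illustrative sketch (helper is hypothetical): rank0_first() is handy when one rank must populate
# a shared cache before the others read from it:
#
#   @rank0_first
#   def prepare_dataset(cache_dir: str) -> None:
#       download_and_extract(cache_dir)  # hypothetical helper; should skip files that already exist
#
#   prepare_dataset("/shared/cache")  # rank 0 runs first; the other ranks run after the barrier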


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from PyTorch Lightning. It wraps a coreModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for PyTorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def training_step(self, *args, **kwargs) -> Any:
        # Cache the original forward method so it can be restored inside the redirected call.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):
            # Unpatch immediately so that any nested calls to the module use the real forward.
            self.module.forward = original_forward
            return self.module.training_step(*_args, **_kwargs)

        # Patch the module's forward so that DDP's forward() (which installs the gradient
        # synchronization hooks) ends up invoking training_step().
        self.module.forward = wrapped_training_step
        # Calling self(*args, **kwargs) triggers DDP.forward(), which calls wrapped_training_step().
        return self(*args, **kwargs)
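

# Illustrative sketch (model class assumed): a module that defines training_step() keeps its normal
# forward(), and the wrapper reroutes training_step() through DDP so gradient hooks still fire:
#
#   ddp_model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
#   loss = ddp_model.training_step(data_batch)  # executes model.training_step() under DDP's forward
#   output = ddp_model(data_batch)              # plain forward() behaves as usual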


def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
    """Wraps the model to enable data parallelism for training across multiple GPU devices.

    Args:
        config_ddp (DDPConfig): The data parallel config.
        model (torch.nn.Module): The PyTorch module.

    Returns:
        model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
            if distributed environment is available, otherwise return the original model.
    """
    if dist.is_available() and dist.is_initialized():
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        try:
            ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
        except Exception as e:
            log.info(e)
            log.info("parallel_state not initialized, treating all GPUs equally for DDP")
            ddp_group = None

        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=config_ddp.find_unused_parameters,
            static_graph=config_ddp.static_graph,
            broadcast_buffers=config_ddp.broadcast_buffers,
            process_group=ddp_group,
        )
    return model
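

# Illustrative sketch (DDPConfig field values assumed): wrap a model after init() has set up the
# process group; without an initialized process group the model is returned unwrapped:
#
#   config_ddp = DDPConfig(find_unused_parameters=False, static_graph=False, broadcast_buffers=True)
#   model = parallel_model_wrapper(config_ddp, model)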


@contextmanager
def ddp_sync_grad(model, enabled):
    r"""
    Context manager to enable/disable gradient synchronizations across DDP processes for a DDP model.
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
    Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.

    Within this context, gradients will be accumulated on module
    variables, which will later be synchronized in the first
    forward-backward pass exiting the context.

    .. warning::
        The forward pass should be included inside the context manager, or
        else gradients will still be synchronized.
    """
    assert isinstance(model, torch.nn.Module)
    if isinstance(model, DistributedDataParallel):
        old_require_backward_grad_sync = model.require_backward_grad_sync
        if model.static_graph and model.require_backward_grad_sync != enabled:
            if model.show_sync_grad_static_graph_warning:
                log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
                model.show_sync_grad_static_graph_warning = False
        else:
            model.require_backward_grad_sync = enabled
    try:
        yield
    finally:
        if isinstance(model, DistributedDataParallel):
            model.require_backward_grad_sync = old_require_backward_grad_sync
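

# Illustrative sketch (training loop variables assumed): skip the gradient all-reduce on all but
# the last micro-batch when accumulating gradients, similar in spirit to DDP's no_sync():
#
#   for i, batch in enumerate(micro_batches):
#       last = i == len(micro_batches) - 1
#       with ddp_sync_grad(ddp_model, enabled=last):
#           loss = ddp_model.training_step(batch)  # assuming training_step() returns the loss
#           loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()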


def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler.
    It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
    in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
    created before calling dist.all_gather().

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.
    """
    if isinstance(data_batches[0], torch.Tensor):
        # Concatenate the local batches, then pad with a dummy sample if this rank has one sample
        # fewer than the largest rank so that all_gather() sees equal-sized tensors.
        data_concat = torch.cat(data_batches, dim=0)
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples
            dummy = torch.empty_like(data_concat[:1])
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")
        # Gather from all ranks, interleave to restore the original sample order, and drop the dummies.
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        data_concat = all_gather_tensor(data_concat.contiguous())
        data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
        if dummy_count > 0:
            data_collate = data_collate[:-dummy_count]
    elif isinstance(data_batches[0], collections.abc.Mapping):
        data_collate = dict()
        for key in data_batches[0].keys():
            data_collate[key] = collate_batches([data[key] for data in data_batches])
    else:
        raise TypeError
    return data_collate
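

# Illustrative sketch (names assumed): gather per-rank validation outputs produced with
# utils.dataloader.DistributedEvalSampler back into full-dataset order on every rank:
#
#   outputs = []
#   for batch in val_dataloader:
#       outputs.append(model.validation_step(batch))  # a dict of tensors in this sketch
#   all_outputs = collate_batches(outputs)            # same keys, leaf tensors concatenated across ranks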


@torch.no_grad()
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather the corresponding tensor from all GPU devices to a list.

    Args:
        tensor (torch.Tensor): PyTorch tensor.

    Returns:
        tensor_list (list[torch.Tensor]): A list of PyTorch tensors gathered from all GPU devices.
    """
    tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


def broadcast(tensor, src, group=None, async_op=False):
    """Broadcast the tensor in place from the src rank to all other ranks; returns the tensor unchanged for single-GPU jobs."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    dist.broadcast(tensor, src=src, group=group, async_op=async_op)