# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import collections.abc
import ctypes
import functools
import os
from contextlib import contextmanager
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Container, Optional

import pynvml
import torch
import torch.distributed as dist
from torch.distributed import get_process_group_ranks

from cosmos_predict1.utils import log
from cosmos_predict1.utils.device import Device

if TYPE_CHECKING:
    from cosmos_predict1.utils.config import DDPConfig

if dist.is_available():
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes

try:
    from megatron.core import parallel_state
except ImportError:
    print("Megatron-core is not installed.")


def init() -> int | None:
    """Initialize distributed training."""
    # Set GPU affinity.
    pynvml.nvmlInit()
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = Device(local_rank)
    # os.sched_setaffinity(0, device.get_cpu_affinity())
    # Set up NCCL communication.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
        torch.cuda.set_device(local_rank)
        # Get the timeout value (in seconds) from the environment variable.
        timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
        # Convert the timeout to an integer (if it isn't already) and then to a timedelta.
        timeout_timedelta = timedelta(seconds=int(timeout_seconds))
        dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
        log.critical(
            f"Initialized distributed program with local rank {local_rank} with timeout {timeout_seconds}",
            rank0_only=False,
        )
    # Increase the L2 fetch granularity for faster speed.
    _libcudart = ctypes.CDLL("libcudart.so")
    # Set device limit on the current device.
    p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
    log.info(f"Running with {get_world_size()} GPUs.")
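

# Usage sketch (assumption, not part of the original file): `init()` relies on the environment
# variables set by a launcher such as torchrun (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR,
# MASTER_PORT). A typical entry point would do:
#
#     # torchrun --nproc_per_node=8 train.py
#     from cosmos_predict1.utils import distributed
#     distributed.init()  # sets the CUDA device and creates the NCCL process group
#     print(distributed.get_rank(), distributed.get_world_size())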


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get world size. How many GPUs are available in this job.

    Returns:
        world_size (int): The total number of GPUs available in this job.
    """
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size(group)
    return world_size


def is_rank0() -> bool:
    """Check if current process is the master GPU.

    Returns:
        (bool): True if this function is called from the master GPU, else False.
    """
    return get_rank() == 0


def is_local_rank0() -> bool:
    """Check if current process is the local master GPU in the current node.

    Returns:
        (bool): True if this function is called from the local master GPU, else False.
    """
    return torch.cuda.current_device() == 0


def device_with_rank(device: str) -> str:
    """If the device is 'cuda' and parallelism over GPUs is enabled, return the rank-qualified
    device string (e.g. 'cuda:3'). Otherwise, return the device as-is."""
    if device == "cuda":
        return f"cuda:{get_rank()}"
    return device


def rank0_only(func: Callable) -> Callable:
    """Apply this function only to the master GPU.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): a function.

    Returns:
        (Callable): A function wrapper executing the function only on the master GPU.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if is_rank0():
            return func(*args, **kwargs)
        else:
            return None

    return wrapper


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on the other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if is_rank0():
            result = func(*args, **kwargs)
        barrier()
        if not is_rank0():
            result = func(*args, **kwargs)
        return result

    return wrapper
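

# Usage sketch (assumption, not from the original file): a typical use of `rank0_first` is to make
# sure a download or on-disk cache is created exactly once per job before the other ranks read it.
# The helper and path below are hypothetical.
#
#     @rank0_first
#     def prepare_dataset_cache():
#         return download_and_unpack("/tmp/dataset")  # hypothetical helper
#
#     prepare_dataset_cache()  # rank 0 populates the cache, the rest wait at the barrier, then reuse it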


def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
    """Wraps the model to enable data parallelism for training across multiple GPU devices.

    Args:
        config_ddp (DDPConfig): The data parallel config.
        model (torch.nn.Module): The PyTorch module.

    Returns:
        model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
            if distributed environment is available, otherwise return the original model.
    """
    if dist.is_available() and dist.is_initialized():
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        try:
            ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
        except Exception as e:
            log.info(e)
            log.info("parallel_state not initialized, treating all GPUs equally for DDP")
            ddp_group = None

        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=config_ddp.find_unused_parameters,
            static_graph=config_ddp.static_graph,
            broadcast_buffers=config_ddp.broadcast_buffers,
            process_group=ddp_group,
        )
    return model
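

# Usage sketch (assumption, not from this file): given a DDPConfig instance exposing the three
# fields read above (find_unused_parameters, static_graph, broadcast_buffers), a model is wrapped as:
#
#     model = model.to("cuda")
#     model = parallel_model_wrapper(config_ddp, model)
#
# If Megatron's parallel_state has been initialized, gradient all-reduce is restricted to the
# data-parallel (with context-parallel) group; otherwise all ranks form a single DDP group.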


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from PyTorch Lightning. It wraps a coreModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for PyTorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.show_sync_grad_static_graph_warning = True

    def training_step(self, *args, **kwargs) -> Any:
        # Cache the original model.forward() method.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):  # noqa: ANN202
            # Unpatch immediately before calling training_step() because itself may want to call the real forward.
            self.module.forward = original_forward
            # The actual .training_step().
            return self.module.training_step(*_args, **_kwargs)

        # Patch the original_module's forward so we can redirect the arguments back to the real method.
        self.module.forward = wrapped_training_step
        # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
        # Without calling self.forward() or model.forward() explicitly, implicit hooks are also executed.
        return self(*args, **kwargs)
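

# Usage sketch (assumption): the wrapped model is expected to define `training_step()` in addition
# to `forward()`. Calling `training_step()` through this wrapper keeps DDP's forward hooks (and
# therefore gradient bucketing) active. Arguments below are placeholders:
#
#     ddp_model = parallel_model_wrapper(config_ddp, model)
#     loss = ddp_model.training_step(data_batch, iteration)  # runs model.training_step() under DDP hooks
#     loss.backward()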


@contextmanager
def ddp_sync_grad(model, enabled):
    r"""
    Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
    Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.
    Within this context, gradients will be accumulated on module
    variables, which will later be synchronized in the first
    forward-backward pass exiting the context.

    .. warning::
        The forward pass should be included inside the context manager, or
        else gradients will still be synchronized.
    """
    assert isinstance(model, torch.nn.Module)
    if isinstance(model, DistributedDataParallel):
        old_require_backward_grad_sync = model.require_backward_grad_sync
        if model.static_graph and model.require_backward_grad_sync != enabled:
            if model.show_sync_grad_static_graph_warning:
                log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
                model.show_sync_grad_static_graph_warning = False
        else:
            model.require_backward_grad_sync = enabled
    try:
        yield
    finally:
        if isinstance(model, DistributedDataParallel):
            model.require_backward_grad_sync = old_require_backward_grad_sync
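

# Usage sketch (assumption): skip the gradient all-reduce on all but the last micro-batch when
# accumulating gradients, similar to DistributedDataParallel.no_sync(). Names are placeholders:
#
#     for i, batch in enumerate(micro_batches):
#         last = i == len(micro_batches) - 1
#         with ddp_sync_grad(ddp_model, enabled=last):
#             loss = ddp_model.training_step(batch, iteration)  # forward must stay inside the context
#             loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()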


def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler.
    It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
    in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
    created before calling dist.all_gather().

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.
    """
    if isinstance(data_batches[0], torch.Tensor):
        # Concatenate the local data batches.
        data_concat = torch.cat(data_batches, dim=0)  # type: ignore
        # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples
            dummy = torch.empty_like(data_concat[:1])
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")
        # Get all concatenated batches from all ranks and concatenate again.
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        data_concat = all_gather_tensor(data_concat.contiguous())
        data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
        # Remove the dummy samples.
        if dummy_count > 0:
            data_collate = data_collate[:-dummy_count]
    elif isinstance(data_batches[0], collections.abc.Mapping):
        data_collate = dict()
        for key in data_batches[0].keys():
            data_collate[key] = collate_batches([data[key] for data in data_batches])  # type: ignore
    else:
        raise TypeError
    return data_collate
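

# Usage sketch (assumption): gather per-rank validation outputs back into dataset order when the
# dataloader is sharded with DistributedEvalSampler. The "video" key is a placeholder:
#
#     local_outputs = []
#     for batch in val_dataloader:
#         local_outputs.append(model(batch["video"].cuda()))
#     all_outputs = collate_batches(local_outputs)  # full validation set, original index order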


def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather the corresponding tensor from all GPU devices to a list.

    Args:
        tensor (torch.Tensor): PyTorch tensor.

    Returns:
        tensor_list (list[torch.Tensor]): A list of PyTorch tensors gathered from all GPU devices.
    """
    tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


def broadcast(tensor, src, group=None, async_op=False):
    """Broadcast the tensor in-place from the source rank to all other ranks (no-op for single-GPU jobs)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    dist.broadcast(tensor, src=src, group=group, async_op=async_op)


def sync_model_states(
    model: torch.nn.Module,
    process_group: Optional[dist.ProcessGroup] = None,
    src: int = 0,
    params_and_buffers_to_ignore: Optional[Container[str]] = None,
    broadcast_buffers: bool = True,
):
    """
    Modified from the DDP source code.
    Synchronizes the parameters and buffers of a model across different processes in a distributed setting.

    This function ensures that all processes in the specified process group have the same initial parameters and
    buffers from the source rank, typically rank 0. It is useful when different processes start with different model
    states and a synchronization is required to ensure consistency across all ranks.

    Args:
        model (nn.Module): The model whose parameters and buffers are to be synchronized.
        process_group (dist.ProcessGroup, optional): The process group for communication. If None,
            the default group is used. Defaults to None.
        src (int, optional): The source rank from which parameters and buffers will be broadcasted.
            Defaults to 0.
        params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
            names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
            included.
        broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.

    Side Effects:
        This function modifies the state of the model in-place to synchronize it with the source rank's model state.

    Raises:
        RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.

    Examples:
        >>> # Download the (otherwise duplicated) model weights from s3 on rank 0 only to save network bandwidth.
        >>> # Useful and saves time when the model weights are huge.
        >>> if dist.get_rank() == 0:
        >>>     model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
        >>> dist.barrier()
        >>> sync_model_states(model)  # sync rank0 weights to other ranks
    """
    if process_group is None:
        process_group = _get_default_group()
    if not params_and_buffers_to_ignore:
        params_and_buffers_to_ignore = set()

    log.info(
        f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
    )

    # Build tuple of (module, parameter) for all parameters.
    modules_and_parameters = [
        (module, parameter)
        for module_name, module in model.named_modules()
        for parameter in [
            param
            # Note that we access module.named_parameters instead of
            # parameters(module). parameters(module) is only needed in the
            # single-process multi device case, where it accesses replicated
            # parameters through _former_parameters.
            for param_name, param in module.named_parameters(recurse=False)
            if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
            # if param.requires_grad
            # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
        ]
    ]

    # Deduplicate any parameters that might be shared across child modules.
    memo = set()
    modules_and_parameters = [
        # "p not in memo" is the deduplication check.
        # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
        (m, p)
        for m, p in modules_and_parameters
        if p not in memo and not memo.add(p)  # type: ignore[func-returns-value]
    ]

    # Build list of parameters.
    parameters = [parameter for _, parameter in modules_and_parameters]
    if len(parameters) == 0:
        return

    _verify_param_shape_across_processes(process_group, parameters)
    _sync_module_states(
        module=model,
        process_group=process_group,
        broadcast_bucket_size=int(250 * 1024 * 1024),
        src=src,
        params_and_buffers_to_ignore=params_and_buffers_to_ignore,
        broadcast_buffers=broadcast_buffers,
    )


def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
    r"""Reduce the tensor across GPUs to the given rank (default: rank 0)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank:
            if reduce == "mean":
                tensor /= world_size
            elif reduce == "sum":
                pass
            else:
                raise NotImplementedError
    return tensor
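

# Usage sketch (assumption): average a scalar metric onto rank 0 for logging. Note that
# dist.reduce() modifies the tensor in-place and the result is only valid on the destination rank:
#
#     loss_avg = dist_reduce_tensor(loss.detach())  # mean over ranks, valid on rank 0 only
#     if is_rank0():
#         log.info(f"avg loss: {loss_avg.item()}")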