# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import io
import pickle
from typing import Any

import torch
import torch.distributed as dist


# Adapted from:
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29
def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcasts an object to the given group.

    It sends the object if called from the source rank and receives the
    object otherwise.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        The broadcasted object.
    """
    if dist.get_rank() == src_rank:
        # Send the object: serialize it into a byte buffer, broadcast the
        # buffer's length first so receivers can size their storage, then
        # broadcast the bytes themselves.
        buffer = io.BytesIO()
        torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the object: first the payload length, then the payload
        # bytes, which are deserialized back into a Python object.
        length_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj
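

# Minimal usage sketch (assumptions: ranks launched via ``torchrun`` with a
# gloo backend; ``_demo_broadcast`` is a hypothetical helper, not part of the
# original module). Rank 0 serializes an arbitrary picklable object and every
# other rank receives an identical copy.
def _demo_broadcast() -> None:
    dist.init_process_group(backend="gloo")
    # Only the source rank needs the real payload; other ranks pass None.
    payload = {"lr": 1e-4, "steps": torch.tensor([1000])} if dist.get_rank() == 0 else None
    payload = broadcast_object(payload, src_rank=0)
    assert payload["lr"] == 1e-4  # every rank now holds the same object
    dist.destroy_process_group()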


def _recursive_copy_to_device(
    value: Any,
    non_blocking: bool,
    device: torch.device,
) -> Any:
    r"""
    Recursively searches lists, tuples, and dicts, and copies tensors to the
    device if possible. Non-tensor values are passed through as-is in the result.

    .. note:: These are all copies, so if there are two objects that reference
        the same object, then after this call there will be two different
        objects referenced on the device.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, (list, tuple)):
        values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value]
        return values if isinstance(value, list) else tuple(values)
    if isinstance(value, collections.abc.Mapping):
        return {
            key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items()
        }
    return value
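

# Usage sketch (illustrative only; ``_demo_copy`` is a hypothetical helper,
# not part of the original module). Nested containers keep their structure
# while every tensor leaf is copied to the target device.
def _demo_copy() -> None:
    state = {"step": 10, "params": [torch.zeros(2), (torch.ones(3), "tag")]}
    moved = _recursive_copy_to_device(state, non_blocking=False, device=torch.device("cpu"))
    assert isinstance(moved["params"][1], tuple)  # container types are preserved
    assert moved["params"][1][1] == "tag"  # non-tensor leaves pass through unchanged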