from dataclasses import dataclass
from functools import partial
from typing import Any, List, Optional, Tuple

import torch
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
    _ProxyTensor,
    fetch_tensor_proxy,
    get_innermost_proxy_mode,
    get_proxy_slot,
    set_proxy_slot,
    track_tensor_tree,
)
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_flatten, tree_map, tree_map_only


@dataclass
class _CommResult:
    # a custom type wrapping both the inplace output tensor and the work handle
    _tensor: torch.Tensor
    _work: torch.distributed._Work


def _wait_comm(comm_result: _CommResult):
    # This function is only used by tracing mode as a call_function node right
    # before consuming a collective result tensor.
    comm_result._work.wait()
    return comm_result._tensor
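
# NB: in a traced graph, `_wait_comm` shows up as a call_function node named
# "wait_comm" (created in `CommTensor.__torch_dispatch__` below); it consumes
# the "comm_result" proxy and yields the plain tensor, which is how the
# generated program orders computation after communication.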


def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
    def wrap(work, e):
        assert isinstance(e, torch.Tensor), (
            "Expecting a collection of tensors as the first element in the "
            "return value of communication operations."
        )

        return _CommResult(e, work)

    # E.g.,
    # allreduce_ returns ([tensor], work)
    # allgather_ returns ([[tensor1, tensor2]], work)
    work = result[1]

    return (tree_map(partial(wrap, work), result[0]), work)
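
# A small illustration of the transformation above (the tensor and work values
# are hypothetical):
#
#   result = ([tensor], work)                  # e.g., as returned by allreduce_
#   _wrap_comm_result(result)
#   # -> ([_CommResult(tensor, work)], work)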


def _get_tracer() -> Optional[torch.fx.Tracer]:
    mode = get_innermost_proxy_mode()
    if mode is None:
        return None
    return mode.tracer


class CommTensor(torch.Tensor):
    r"""
    A Tensor subclass to wrap input tensors for collective communications.

    This Tensor subclass works for both eager and tracing mode.
    In eager mode, it will record whether the inplace collective communication
    has been launched using this Tensor and remember the corresponding work
    handle. If yes, it will explicitly call wait() in the
    ``__torch_dispatch__`` function before subsequent operations consume the
    value of the Tensor.

    In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
    ``__torch_dispatch__`` function.
    1. The first node is inserted right after the communication, wrapping both
    the inplace output tensor and the returned work handle into a custom
    ``_CommResult`` type. We have to do this because ``ProxyTorchDispatchMode``
    only handles ``torch.Tensor``, ``_ProxyTensor``, and ``torch.nn.Parameter``
    objects, and would otherwise treat the work handle as a constant and embed
    it into the graph. As a result, during execution, it would use the work
    handle created during tracing, leading to wrong results. The solution in
    this test is to manually create a proxy on the return value of
    ``allreduce_``, which is ``([tensor], work)``, and wrap that into
    ``([_CommResult(tensor, work)], work)``. In this way, subsequent nodes can
    directly consume ``_CommResult``.
    2. The second node is inserted right before any subsequent node reads from
    ``_CommResult``. It will call ``wait()`` on the stashed work handle to
    ensure that computation waits for communication. See the usage sketch at
    the end of this module.
    """

    _supported_comms: List[str] = [
        "_allgather_base_",
        "_reduce_scatter_base_",
        "allreduce_",
        "allgather_",
        "alltoall_",
        "broadcast_",
        "reduce_scatter_",
        "scatter_",
    ]

    _tensor: torch.Tensor
    _work: Optional[torch.distributed._Work]

    def __new__(cls, tensor: torch.Tensor):
        t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
        if get_innermost_proxy_mode() is None:
            # noop for eager mode
            return tensor

        # Use the non-CommTensor to avoid nested CommTensor wrapping.
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        # The tensor object wrapped by this CommTensor.
        # NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
        r._tensor = tensor  # type: ignore[attr-defined]
        # Record the LAST `work` object returned by collective communication
        # operations. If this is None, it means no collectives have been
        # called since the last time a tensor was wrapped by CommTensor.
        r._work = None  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"CommTensor({self._tensor}, work={self._work})"

    # disable __torch_function__ so that CommTensor can recursively dispatch
    # with ProxyTorchDispatchMode in make_fx
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def _is_supported(cls, op_name):
        return any(comm in op_name for comm in cls._supported_comms)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # shared states when unwrapping args
        tracer: Optional[torch.fx.Tracer] = None
        work: Optional[torch.distributed._Work] = None

        # Unwrap to ._tensor if this is a CommTensor, and insert/call wait()
        # if communication has been launched on this tensor.
        def unwrap(e: Any):
            if isinstance(e, CommTensor):
                nonlocal tracer, work
                work = e._work
                # TODO(ezyang): I don't really understand what's going on
                # here, but it seems that tracer doesn't reflect whether or
                # not there is ambient tracing going on, but rather, whether
                # or not we will trace THIS particular invocation. If we
                # have a nested CommTensor, the outer layer doesn't actually
                # trace and we only trace the inner layer.
                if not isinstance(e._tensor, CommTensor):
                    tracer = _get_tracer()

                if work is not None:
                    if tracer is not None:
                        # insert a node into the traced graph
                        proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
                            "call_function",
                            _wait_comm,
                            (get_proxy_slot(e._tensor, tracer).proxy,),
                            {},
                            name="wait_comm",
                        )
                        # HACK: update the proxy for the inplace output
                        set_proxy_slot(e._tensor, tracer, proxy_res)
                    # In eager mode, simply wait. During tracing, we still
                    # need to wait here to make sure the execution during
                    # tracing is correct.
                    work.wait()

                # communication has been waited on; stop propagating CommTensor
                return e._tensor
            else:
                return e

        def wrap(e: Any):
            return CommTensor(e) if isinstance(e, torch.Tensor) else e

        def set_work(work: torch.distributed._Work, e: Any):
            if isinstance(e, CommTensor):
                e._work = work  # type: ignore[attr-defined]
            elif isinstance(e, torch.Tensor):
                raise RuntimeError(
                    "Type of output tensors from collective communication during "
                    "tracing should always be CommTensor instead of torch.Tensor"
                )
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        if cls._is_supported(func.__name__):
            if tracer is not None:
                # in tracing mode, get proxies for args
                proxy_args, proxy_kwargs = tree_map_only(
                    _ProxyTensor,
                    lambda e: e.proxy,
                    tree_map_only(
                        torch.Tensor,
                        fetch_tensor_proxy(tracer),
                        (unwrapped_args, unwrapped_kwargs),
                    ),
                )

                # get proxy for output tuple
                proxy_res = func(*proxy_args, **proxy_kwargs)
                assert isinstance(proxy_res, torch.fx.Proxy)
                # insert a node that wraps the output tuple into
                # _CommResult(tensor, work)
                comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
                    "call_function",
                    _wrap_comm_result,
                    (proxy_res,),
                    {},
                    name="comm_result",
                )

                with no_dispatch():
                    # disable dispatch to avoid triggering ProxyTorchDispatchMode logic
                    out = func(*unwrapped_args, **unwrapped_kwargs)

                # wrap the output with the proxy of _CommResult, so that
                # subsequent ops can link to it
                track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)

                # N.B.: we still need to remember the work handle here, and
                # wait for it later to make sure the execution during tracing
                # is correct. Also, remember that comm has already been
                # launched. args[0] is always the collection of output tensors.
                pytree.tree_map_(partial(set_work, out[1]), args[0])

                # HACK: update the proxy on the input argument as this is an
                # inplace collective communication
                flat_args, args_spec = tree_flatten(unwrapped_args[0])
                flat_out, out_spec = tree_flatten(out[0])
                for a, o in zip(flat_args, flat_out):
                    set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))

                return out
            else:
                # in eager mode, simply remember the work handle as an attribute
                out = func(*unwrapped_args, **unwrapped_kwargs)
                pytree.tree_map_(partial(set_work, out[1]), args[0])
                return out
        else:
            if work is not None:
                return func(*unwrapped_args, **unwrapped_kwargs)
            else:
                # we need to propagate CommTensor wrapping until the first
                # subsequent operation has waited for it
                return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
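

# A minimal usage sketch, not part of the module above. ``make_fx`` and
# ``dist.all_reduce`` are public PyTorch APIs, but this helper, its shapes,
# and the assumption that a default process group has already been
# initialized are illustrative only.
def _example_trace_allreduce(x: torch.Tensor) -> torch.fx.GraphModule:
    import torch.distributed as dist
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(t: torch.Tensor) -> torch.Tensor:
        t = CommTensor(t)
        dist.all_reduce(t)  # inplace collective; the work handle is stashed on t
        # Reading t triggers a "wait_comm" node (tracing) or an actual wait()
        # (eager) before the multiply consumes the communicated value.
        return t * 2

    # The traced graph is expected to contain "comm_result" and "wait_comm"
    # call_function nodes around the collective.
    return make_fx(fn)(x)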