# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import operator
from typing import cast, Dict, List, Optional, Sequence, Tuple

import torch
import torch.distributed as dist
import torch.distributed._tensor.api as dtensor
import torch.distributed._tensor.random as random
from torch.distributed._tensor.op_schema import (
    _is_inplace_op,
    _is_out_variant_op,
    OpInfo,
    OpSchema,
    OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta
from torch.distributed._tensor.random import is_rng_supported_mesh
from torch.distributed._tensor.redistribute import redistribute_local_tensor
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.tp_conv import (
    convolution_backward_handler,
    convolution_handler,
)
from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.utils import _cxx_pytree as pytree
except ImportError:
    from torch.utils import _pytree as pytree  # type: ignore[no-redef]

aten = torch.ops.aten


def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Decompose an op into core ATen ops. This handler is mostly here for
    inference-mode usage, where the incoming ops may not be core ATen ops.
    """
    r = op_call.decompose(*args, **kwargs)
    if r is not NotImplemented:
        return r
    else:
        raise RuntimeError("Decomposition failed")
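

# Illustrative sketch (not part of the dispatch path): ``OpOverload.decompose``
# re-runs the op through its registered decomposition, so each resulting core
# ATen op is dispatched through DTensor individually. For example, invoking the
# handler on ``aten.linear`` roughly amounts to the call below (the DTensor
# arguments are hypothetical):
#
#   out = decompose_handler(
#       torch.ops.aten.linear.default, (input_dt, weight_dt, bias_dt), {}
#   )
#   # internally re-runs linear as matmul/addmm-style core ATen ops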


def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    lhs = cast(torch.Tensor, args[0])
    rhs = cast(torch.Tensor, args[1])
    return lhs.shape == rhs.shape
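

# Note: ``shape`` on a DTensor wrapper is the *global* shape, so this handler
# answers ``aten.is_same_size`` purely from wrapper metadata, without touching
# local shards or issuing a collective. A minimal sketch (``dt_a``/``dt_b``
# are hypothetical DTensors):
#
#   same = torch.ops.aten.is_same_size.default(dt_a, dt_b)  # compares global shapes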


class OpDispatcher:
    """
    Op dispatching class instance that handles args/kwargs pre-processing
    (unwrapping), sharding propagation, redistribution of local args, the local
    computation, and post-processing (re-wrapping). It also handles any
    op-specific logic where necessary.
    """

    def __init__(self) -> None:
        self.sharding_propagator = ShardingPropagator()
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
        }
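
    # Illustrative sketch of how a custom handler plugs in (``my_op`` and
    # ``my_op_handler`` are hypothetical, not part of this module): ops
    # registered here skip sharding propagation entirely and take full
    # control of the call.
    #
    #   def my_op_handler(op_call, args, kwargs):
    #       ...  # custom logic; must return the op's result itself
    #
    #   dispatcher._custom_op_handlers[torch.ops.aten.my_op.default] = my_op_handler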

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic.
        """
        # operators that do not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensors and sharding infos into an OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is None:
            # For a non-participating device, we do:
            #   1. if the return type is scalar, set the local result to None;
            #      the local results from all devices will then be all-gathered
            #      and a reduce op will be performed on the list of results
            #      with an appropriate operator (for bool type, we use AND to
            #      reduce by default; we can extend this for more ops if
            #      necessary).
            #   2. if the return type is Tensor or List[Tensor], return empty
            #      tensor(s) with the correct dtype.
            spec = output_sharding.output_spec
            ret_list = op_info.schema.op._schema.returns

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results: object = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, List)
                    if None in local_results:
                        ret_type = str(ret_list[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )
        else:
            if output_sharding.needs_redistribute:
                # compute locally, with a redistribute first if needed
                assert output_sharding.schema_suggestions is not None
                self.redistribute_local_args(
                    op_info, output_sharding.schema_suggestions[0]
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run the local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops and is_rng_supported_mesh(mesh):
                if not random._rng_tracker:
                    raise RuntimeError(
                        "A CudaRNGStateTracker instance must be instantiated "
                        "before executing a random op over a DTensor. "
                        "Try calling random.manual_seed() or distribute_tensor() "
                        "before executing a DTensor random op."
                    )
                # for a DTensor random operator, run it within a distribute region
                with random._rng_tracker._distribute_region(
                    cast(dtensor.DTensor, args[0])._spec
                ):
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        # communicate the result to all ranks for operators that return a scalar value
        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)
                obj_list = list(filter(lambda x: x is not None, obj_list))
                # perform the reduce on the collection with the AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # an inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # an out variant could possibly have multiple out args (e.g. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)
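
    # Illustrative end-to-end sketch (assumes an initialized process group on
    # 2 devices; user code never calls ``dispatch`` directly, since DTensor
    # routes ops here via ``__torch_dispatch__``):
    #
    #   from torch.distributed._tensor import Shard, distribute_tensor
    #   from torch.distributed.device_mesh import init_device_mesh
    #
    #   mesh = init_device_mesh("cuda", (2,))
    #   dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    #   out = torch.sin(dt)  # -> dispatch(aten.sin.default, (dt,), {}) under the hood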

    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
        # NOTE: it's very rare that we need to reshard kwargs, so we intentionally skip them.
        # TODO: the op schema should probably just remain flattened so that we can
        # avoid this tree flatten; all the ops need to be fixed before doing this.
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)
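
    # Illustrative sketch: when the suggested spec differs from the current one
    # (say ``Shard(0)`` -> ``Replicate()``), ``redistribute_local_tensor``
    # issues the matching collective (an all-gather along dim 0 in this case).
    # The variable names below are hypothetical:
    #
    #   new_local = redistribute_local_tensor(
    #       local_shard,     # this rank's piece of the tensor
    #       current_spec,    # placements=(Shard(0),)
    #       suggested_spec,  # placements=(Replicate(),)
    #   )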

    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
        # get runtime schema to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                args_schema.append(arg._spec)
                local_args.append(arg._local_tensor)
                if mesh is not None:
                    if mesh != arg.device_mesh:
                        raise NotImplementedError(
                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
                        )
                else:
                    mesh = arg.device_mesh
            elif isinstance(arg, torch.Tensor):
                if arg.ndim == 0 and mesh is not None:
                    # a scalar tensor can be safely treated as replicated
                    args_schema.append(
                        DTensorSpec(
                            mesh,
                            (Replicate(),) * mesh.ndim,
                            tensor_meta=TensorMeta(
                                shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
                            ),
                        )
                    )
                    local_args.append(arg)
                else:
                    raise RuntimeError(
                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                        " torch.Tensor to DTensor before calling distributed operators!"
                    )
            else:
                args_schema.append(arg)
                local_args.append(arg)
        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                kwargs_schema[k] = v._spec
                local_kwargs[k] = v._local_tensor
                if mesh is not None:
                    if mesh != v.device_mesh:
                        raise NotImplementedError(
                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
                        )
                else:
                    mesh = v.device_mesh
            elif isinstance(v, torch.Tensor):
                raise RuntimeError(
                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                    " torch.Tensor to DTensor before calling distributed operators!"
                )
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info
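
    # Illustrative sketch of the unwrapping contract: each DTensor argument
    # contributes its spec to the schema and its local shard to the local args,
    # so for a call like ``torch.add(dt_a, dt_b)`` (hypothetical DTensors),
    # roughly:
    #
    #   op_info.schema.args_schema  ->  (dt_a._spec, dt_b._spec)
    #   op_info.local_args          ->  (dt_a._local_tensor, dt_b._local_tensor)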

    @staticmethod
    def wrap(res: object, spec: OutputSpecType) -> object:
        def to_dt(res, spec):
            assert spec is not None and isinstance(
                spec, DTensorSpec
            ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
            assert spec.tensor_meta is not None
            return dtensor.DTensor(
                res,
                spec.mesh,
                spec.placements,
                shape=spec.tensor_meta.shape,
                dtype=spec.tensor_meta.dtype,
                requires_grad=res.requires_grad,
                stride=spec.tensor_meta.stride,
            )

        if isinstance(res, torch.Tensor):
            return to_dt(res, spec)
        elif isinstance(res, (list, tuple)):
            assert spec is not None and isinstance(
                spec, (list, tuple)
            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
            res_list = []
            for e, s in zip(res, spec):
                # NOTE: local results might contain Optional[Tensor] returns from an
                # ATen op (e.g. native_layer_norm.backward), so we need to handle
                # that case and make sure we don't wrap None with DTensor.
                if isinstance(e, (list, tuple)) and isinstance(s, (list, tuple)):
                    res_list.append(type(e)([to_dt(ee, ss) for ee, ss in zip(e, s)]))
                elif e is not None and s is not None:
                    res_list.append(to_dt(e, s))
                else:
                    res_list.append(None)  # type: ignore[arg-type]

            return tuple(res_list) if isinstance(res, tuple) else res_list
        else:
            # if res contains only non-tensor values, simply return it without rewrapping
            return res