Adi-69s's picture
Upload 5061 files
b2659ad verified
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import operator
from typing import cast, Dict, List, Optional, Sequence, Tuple
import torch
import torch.distributed as dist
import torch.distributed._tensor.api as dtensor
import torch.distributed._tensor.random as random
from torch.distributed._tensor.op_schema import (
_is_inplace_op,
_is_out_variant_op,
OpInfo,
OpSchema,
OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta
from torch.distributed._tensor.random import is_rng_supported_mesh
from torch.distributed._tensor.redistribute import redistribute_local_tensor
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree # type: ignore[no-redef]
aten = torch.ops.aten
def decompose_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
"""
Decomposes a op to core ATen op, this handler is mostly here
for inference mode usage where the ops are not core aten ops.
"""
r = op_call.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
else:
raise RuntimeError("Decomposition failed")
def is_same_size_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> bool:
lhs = cast(torch.Tensor, args[0])
rhs = cast(torch.Tensor, args[1])
return lhs.shape == rhs.shape
class OpDispatcher:
"""
Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
also handles any op specific logic if necessary.
"""
def __init__(self) -> None:
self.sharding_propagator = ShardingPropagator()
self._random_ops = {
aten.native_dropout.default,
aten.normal_.default,
aten.rand_like.default,
aten.randn_like.default,
aten.randint_like.default,
aten.randint_like.low_dtype,
aten.randint_like.low_dtype_out,
aten.uniform_.default,
aten.bernoulli.default,
aten.bernoulli_.float,
}
self._custom_op_handlers = {
aten.linear.default: decompose_handler,
aten.is_same_size.default: is_same_size_handler,
aten.convolution.default: convolution_handler,
aten.convolution_backward.default: convolution_backward_handler,
}
def dispatch(
self,
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
"""
Main dispatching logic
"""
# operators that does not need to go through sharding propagation
if op_call in self._custom_op_handlers:
return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator]
# extract local tensor and sharding infos to a OpInfo
op_info = self.unwrap_to_op_info(op_call, args, kwargs)
self.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
mesh = op_info.mesh
if mesh.get_coordinate() is None:
# For a non-participating device, we do:
# 1. if the return type is scalar, set the local result to None.
# The local results from all devices will then be all-gathered
# and a reduce op will be performed on the list of results
# with appropriate operators:
# for bool type, we by default use AND to reduce;
# we can extend for more ops if necessary.
# 2. if the return type is Tensor or List[Tensor], return empty
# tensor(s) with correct dtype.
spec = output_sharding.output_spec
ret_list = op_info.schema.op._schema.returns
if spec is None:
# For a scalar return type, the non-participating device has None
# as its local result
local_results: object = None
else:
def default_tensor(spec: DTensorSpec) -> torch.Tensor:
if spec.tensor_meta is not None:
shape = spec.tensor_meta.shape
dtype = spec.tensor_meta.dtype
if len(shape) == 0:
# scalar tensor
return torch.zeros((), dtype=dtype)
else:
# non-scalar tensor
return torch.tensor([], dtype=dtype)
else:
raise RuntimeError(f"{spec} has no tensor metadata.")
if isinstance(spec, DTensorSpec):
# return a Tensor value
local_results = default_tensor(spec)
elif isinstance(spec, Sequence):
# return a List[Tensor] value
local_results = [
default_tensor(s) if s is not None else None for s in spec
]
assert isinstance(local_results, List)
if None in local_results:
ret_type = str(ret_list[0].type)
raise NotImplementedError(
f"return type {ret_type} in DTensor op is not supported"
)
else:
if output_sharding.needs_redistribute:
# compute locally with redistribute first if needed
assert output_sharding.schema_suggestions is not None
self.redistribute_local_args(
op_info, output_sharding.schema_suggestions[0]
)
local_tensor_args = (
pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
if op_info.args_tree_spec
else op_info.local_args
)
# run local op computation with potentially modified args/kwargs
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
if op_call in self._random_ops and is_rng_supported_mesh(mesh):
if not random._rng_tracker:
raise RuntimeError(
"A CudaRNGStateTracker instance must be instantiated "
"before executing a random op over a DTensor. "
"Try calling random.manual_seed() or distribute_tensor() "
"before executing a DTensor random op."
)
# For DTensor random operator, run it within a distribute region
with random._rng_tracker._distribute_region(
cast(dtensor.DTensor, args[0])._spec
):
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
# communicate the result to all ranks for some operators that return scalar value
if output_sharding.output_spec is None:
if op_call == aten.equal.default:
obj_list = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(obj_list, local_results)
obj_list = list(filter(lambda x: x is not None, obj_list))
# perform reduce on the collection with AND op
local_results = functools.reduce(operator.and_, obj_list, True)
if _is_inplace_op(op_call):
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
return args[0]
else:
return None
elif _is_out_variant_op(op_call):
# out variant could possibly have multiple out args (i.e. lu_unpack.out)
output_specs = (
(output_sharding.output_spec,)
if not isinstance(output_sharding.output_spec, tuple)
else output_sharding.output_spec
)
out_dts = []
spec_idx = 0
for argument in op_call._schema.arguments:
if argument.is_out:
out_dt = cast(dtensor.DTensor, kwargs[argument.name])
out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
out_dts.append(out_dt)
spec_idx += 1
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
return self.wrap(local_results, output_sharding.output_spec)
@staticmethod
def redistribute_local_args(
op_info: OpInfo,
suggested_input_schema: OpSchema,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
# TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
# Need to fix all the ops before doing this.
if op_info.args_tree_spec is not None:
flatten_args_schema_to_reshard = tuple(
pytree.tree_leaves(suggested_input_schema.args_schema)
)
else:
flatten_args_schema_to_reshard = suggested_input_schema.args_schema
new_local_args: List[object] = []
for i, arg_spec in enumerate(op_info.flat_args_schema):
reshard_arg_spec = flatten_args_schema_to_reshard[i]
if isinstance(arg_spec, DTensorSpec):
local_tensor = cast(torch.Tensor, op_info.local_args[i])
if arg_spec != reshard_arg_spec:
resharded_local_tensor = redistribute_local_tensor(
local_tensor, arg_spec, reshard_arg_spec
)
new_local_args.append(resharded_local_tensor)
else:
new_local_args.append(local_tensor)
else:
new_local_args.append(reshard_arg_spec)
op_info.local_args = tuple(new_local_args)
def unwrap_to_op_info(
self,
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> OpInfo:
# get runtime schema to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
op_call, None
)
if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
# flatten args/kwargs when necessary
tree_args, args_spec = pytree.tree_flatten(args)
args_list: Sequence[object] = tree_args
else:
args_list, args_spec = args, None
args_schema: List[object] = []
kwargs_schema: Dict[str, object] = {}
local_args: List[object] = []
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None
for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
if mesh != arg.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
if arg.ndim == 0 and mesh is not None:
# scalar tensor can be safely treated as replicated
args_schema.append(
DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
),
)
)
local_args.append(arg)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
args_schema.append(arg)
local_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, dtensor.DTensor):
kwargs_schema[k] = v._spec
local_kwargs[k] = v._local_tensor
if mesh is not None:
if mesh != v.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
else:
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
else:
kwargs_schema[k] = v
local_kwargs[k] = v
assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
op_info = OpInfo(
mesh,
OpSchema(
op_call,
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema),
kwargs_schema,
schema_info=runtime_schema_info,
),
args_schema,
tuple(local_args),
local_kwargs,
args_spec,
)
return op_info
@staticmethod
def wrap(res: object, spec: OutputSpecType) -> object:
def to_dt(res, spec):
assert spec is not None and isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
assert spec.tensor_meta is not None
return dtensor.DTensor(
res,
spec.mesh,
spec.placements,
shape=spec.tensor_meta.shape,
dtype=spec.tensor_meta.dtype,
requires_grad=res.requires_grad,
stride=spec.tensor_meta.stride,
)
if isinstance(res, torch.Tensor):
return to_dt(res, spec)
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(
spec, (list, tuple)
), f"output spec does not match with output! Expected list/tuple, got {spec}."
res_list = []
for e, s in zip(res, spec):
# NOTE: local results might return Optional Tensor from ATen op, so we need
# to handle that case and make sure we don't wrap None with DTensor.
# (i.e. native_layer_norm.backward)
if isinstance(e, (list, tuple)) and isinstance(s, (list, tuple)):
res_list.append(type(e)([to_dt(ee, ss) for ee, ss in zip(e, s)]))
elif e is not None and s is not None:
res_list.append(to_dt(e, s))
else:
res_list.append(None) # type: ignore[arg-type]
return tuple(res_list) if isinstance(res, tuple) else res_list
else:
# if the res contains only non tensor values, we simply return it without rewrapping
return res