import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._higher_order_ops.cond import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)


# TODO: We add this wrapper to prevent dynamo from tracing into map_wrapper;
# remove the wrapper call when dynamo is ready.
class MapWrapper(HigherOrderOperator):
    def __call__(self, f, xs, *args):
        return map_wrapper(f, xs, *args)


map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)

dummy_aot_config = AOTConfig(
    fw_compiler=None,
    bw_compiler=None,
    partition_fn=None,
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)
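
# Illustrative usage sketch (not part of the implementation; the names below are made up for the
# example): semantically, ``map(f, xs, *args)`` applies ``f`` to each slice of ``xs`` along the
# leading dimension and stacks the results, which is what the eager ``map_dense`` implementation
# further down does:
#
#   def body(x, y):
#       return x.sin() + y
#
#   xs = torch.randn(3, 4)
#   y = torch.randn(4)
#   out = map(body, xs, y)                             # shape (3, 4)
#   reference = torch.stack([body(x, y) for x in xs])  # same values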
def create_fw_bw_graph(f, num_mapped_args, *args):
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and the Python key. Currently, we only suspend functionalization, but more
    # can be added when required. We would run into two problems if we didn't suspend
    # functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as
    # _to_functional_tensor_wrapper, but they will be unwrapped before entering
    # ProxyTorchDispatchMode as part of the dispatching. However, it's the outside wrapper that
    # the tracer creates proxies for. This causes the tracer to fail to fetch the proxy for the
    # inputs and to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the
    # tracer only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    with suspend_functionalization():
        with disable_proxy_modes_tracing():

            def from_fun(t):
                if isinstance(t, torch.Tensor):
                    if t.dtype != torch.bool:
                        return torch.empty_strided(
                            t.size(),
                            t.stride(),
                            dtype=t.dtype,
                            requires_grad=t.requires_grad,
                        )
                    else:
                        return t.clone()
                return t

            example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
            example_pos_args = [
                from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expected outputs of map to only contain tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            def fw_with_masks(*args):
                fw_out = f(*args)
                return fw_out, [
                    True
                    if isinstance(ret, torch.Tensor) and ret.requires_grad
                    else False
                    for ret in fw_out
                ]

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # In order to keep map functional for the backward graph,
            # we clone outputs that alias inputs.
            input_storage = {
                StorageWeakRef(arg._typed_storage())
                for arg in example_args
                if isinstance(arg, torch.Tensor)
            }

            def maybe_clone(t):
                if (
                    isinstance(t, torch.Tensor)
                    and StorageWeakRef(t._typed_storage()) in input_storage
                ):
                    return t.clone()
                return t

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph
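
# Rough illustration of what create_fw_bw_graph produces (hedged; exact node names depend on
# make_fx): for a body like
#
#   def f(x):
#       return x.sin()
#
# ``fw_graph`` is an fx.GraphModule whose forward roughly calls ``torch.ops.aten.sin.default``,
# while ``joint_graph`` takes ``(x, grad_out)`` and returns the input gradient (here roughly
# ``grad_out * x.cos()``), with any gradient that aliases an input cloned so the backward graph
# stays functional.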
def map_wrapper(f, xs, *args):
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
    )


class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        with torch._C._AutoDispatchBelowAutograd():
            return (*map_impl(fw_graph, num_mapped_args, *flat_args),)

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        grads = map_impl(
            ctx._joint_graph,
            ctx._num_mapped_args + len(flat_grads),
            *fw_mapped_args,
            *flat_grads,
            *pos_args,
        )
        return None, None, None, *grads


def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
    xs = list(args[:num_mapped])
    pos_args = list(args[num_mapped:])
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f
    if not isinstance(body_graph, torch.fx.GraphModule):
        body_graph = make_fx(body_graph)(*example_input, *pos_args)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    next_name = None
    i = 0
    while not next_name:
        candidate = f"body_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    proxy_mode.tracer.root.register_module(next_name, body_graph)
    node_args = (body_graph, num_mapped, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensors, got {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            "Leaves of xs must have the same leading dimension size, "
            f"got shapes {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for leaves in a:
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees
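
# Behavior sketch (hedged; the dict below is just an arbitrary pytree chosen for illustration):
#
#   xs = {"a": torch.zeros(2, 3), "b": torch.ones(2)}
#   _unstack_pytree(xs)
#   # -> [{"a": <shape (3,)>, "b": <shape ()>},   # slice 0 along the leading dim
#   #     {"a": <shape (3,)>, "b": <shape ()>}]   # slice 1 along the leading dim
#
# ``_stack_pytree`` below is the inverse: given a list of pytrees with a common spec, it stacks
# matching leaves back along a new leading dimension (and keeps None leaves as None).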
def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)

    b = zip(*flat_out)

    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when the corresponding forward inputs
            # don't require grad. When we eagerly execute the backward graph, we need to call
            # _stack_pytree on its output, therefore we need to deal with None outputs.
            stacked_out.append(None)
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
    xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]
    pytrees = []
    for inp in _unstack_pytree(xs):
        pytrees.append(f(*inp, *pos_args))
    return _stack_pytree(pytrees)


@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, num_mapped_args, *args):
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_map(mode, map_impl, f, num_mapped, *args)
        else:
            return map_impl(f, num_mapped, *args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, num_mapped, *args):
    return map_dense(f, num_mapped, *args)


@map_impl.py_impl(DispatchKey.Functionalize)
def map_func(f, num_mapped, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )
    mode = "mutations_and_views" if reapply_views else "mutations"

    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        functional_map_fn = functionalize(f, remove=mode)
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)

        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=0)
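
# Illustrative (hedged) examples of body functions that the mutation/aliasing checks above and
# below are meant to reject; both would raise UnsupportedAliasMutationException when map is
# functionalized/traced:
#
#   def bad_mutation(x):
#       x.add_(1)          # mutates a mapped input in place
#       return x.sin()
#
#   def bad_alias(x):
#       return x.view(-1)  # output aliases the input's storage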
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
    """
    Functionalization implementation for torch.map. Currently:
      1. We don't allow any input mutation inside the map function.
      2. Our check for the above condition is not exhaustive.
    """
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    reapply_views = interpreter.functionalize_add_back_views()
    mode = "mutations_and_views" if reapply_views else "mutations"
    # At this point, we will see functionalized tensors, so we need to unwrap them first.
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )

    functional_map_fn = functionalize(f, remove=mode)

    with interpreter.lower():
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())


# TODO(voz): Make this automatic for keys; this is very ugly atm.
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
map_impl.fallthrough(DispatchKey.BackendSelect)
map_impl.fallthrough(DispatchKey.AutocastCPU)
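
# Hedged end-to-end sketch (the names below are made up for illustration): wrapping a map call in
# torch.func.functionalize is roughly what routes map_impl through map_functionalize above, which
# unwraps the functional tensors, re-runs the mutation/alias checks on the body, and re-wraps the
# result at the interpreter's level:
#
#   def body(x):
#       return x.sin()
#
#   def wrapper(xs):
#       return map(body, xs)
#
#   out = torch.func.functionalize(wrapper)(torch.randn(3, 4))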