|
from dataclasses import dataclass |
|
from functools import partial |
|
import torch |
|
from torch.multiprocessing.reductions import StorageWeakRef |
|
|
|
import torch.utils._pytree as pytree |
|
|
|
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard |
|
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize |
|
from torch._ops import PyOperator |
|
from torch._subclasses.fake_tensor import FakeTensorMode |
|
from torch.fx.experimental.proxy_tensor import ( |
|
disable_proxy_modes_tracing, |
|
ProxyTorchDispatchMode, |
|
make_fx, |
|
track_tensor_tree, |
|
unwrap_proxy, |
|
) |
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata |
|
from torch.utils._python_dispatch import ( |
|
_get_current_dispatch_mode, |
|
_pop_mode_temporarily, |
|
) |
|
from torch.utils._pytree import tree_flatten |
|
|
|
|
|
@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    """Raised when a ``cond`` branch may mutate or alias one of its inputs.

    ``cond_functionalize`` raises it with a human-readable ``reason``.
    ``_has_potential_branch_input_mutation`` and ``_has_potential_branch_input_alias``
    catch it during tracing and treat it as a positive detection.

    NOTE: because this is a non-frozen ``@dataclass`` with the default
    ``eq=True``, the generated ``__eq__`` sets ``__hash__`` to ``None``, so
    instances are not hashable.
    """

    # Human-readable explanation of why the branch was rejected.
    reason: str
|
|
|
|
|
# Definition of the `cond` operator.
#
# `cond` is a Python-level operator. It needs an implementation for each
# dispatch key / mode it can be reached through; those implementations are
# attached further down via `cond.py_impl(...)`.
cond = PyOperator("cond")
|
|
|
|
|
def _flat_graph_outputs(graph_module):
    """Return the pytree-flattened arguments of every ``output`` node of ``graph_module``."""
    outs = [arg for node in graph_module.graph.nodes if node.op == 'output' for arg in node.args]
    flat_outs, _ = pytree.tree_flatten(outs)
    return flat_outs


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    """
    Record a ``cond`` call into the graph being built by ``proxy_mode``.

    Steps:

    1. Trace each branch separately with ``make_fx``, with proxy-mode tracing disabled.
    2. Check that the branches' flattened outputs match in count and in ``tensor_meta``.
    3. Register the two resulting GraphModules on the tracer root as
       ``true_graph_{i}`` / ``false_graph_{i}``, using the first free index ``i``.
    4. Emit a ``call_function`` node named ``"conditional"`` that targets
       ``func_overload``, with ``(pred, true_graph, false_graph, operands)`` as
       its (proxy-unwrapped) arguments.

    Args:
        proxy_mode: the active ``ProxyTorchDispatchMode``; its tracer receives
            the submodules and the new node.
        func_overload: the operator recorded as the node's target.
        pred: the branch predicate; recorded in the node's arguments.
        true_fn, false_fn: callables invoked as ``fn(*operands)``.
        operands: a list or tuple of tensors.

    Returns:
        The outputs of ``false_fn(*operands)``, tracked against the new proxy node.

    Raises:
        RuntimeError: if the two branches produce a different number of outputs,
            or outputs whose ``tensor_meta`` differ.
    """
    assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

    with disable_proxy_modes_tracing():
        true_graph = make_fx(true_fn)(*operands)
        false_graph = make_fx(false_fn)(*operands)

    flat_true_outs = _flat_graph_outputs(true_graph)
    flat_false_outs = _flat_graph_outputs(false_graph)
    # Explicit raises instead of asserts: these must still fire under `python -O`.
    # Messages mirror the FakeTensorMode implementation of cond.
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = true_out.meta['tensor_meta']
        false_meta = false_out.meta['tensor_meta']
        if true_meta != false_meta:
            raise RuntimeError(
                f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}")

    # register_module needs explicit names. Pick the first index i for which
    # `true_graph_{i}` is not yet an attribute of the tracer root.
    root = proxy_mode.tracer.root
    i = 0
    while hasattr(root, f"true_graph_{i}"):
        i += 1
    true_name = f"true_graph_{i}"
    false_name = f"false_graph_{i}"
    # True/false graphs are always registered in pairs, so the false name must be free too.
    assert not hasattr(root, false_name)

    root.register_module(true_name, true_graph)
    root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)
    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)
    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="conditional")

    # The branches were checked above to have matching output tensor metadata,
    # so (for tracing purposes) the outputs of either branch can stand in for
    # the node's result. The false branch is what this code has always used.
    out = false_fn(*operands)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
|
|
|
|
@cond.py_impl(DispatchKey.CUDA)
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    """
    Eager CPU/CUDA implementation of ``cond``.

    Requires that no Python dispatch mode is active. Evaluates ``pred`` for
    truthiness and returns ``true_fn(*operands)`` or ``false_fn(*operands)``.
    """
    assert _get_current_dispatch_mode() is None, "Mode should never be enabled for CPU/CUDA key"
    chosen_branch = true_fn if pred else false_fn
    return chosen_branch(*operands)
|
|
|
|
|
@cond.py_impl(DispatchKey.AutogradCUDA)
@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    """
    Autograd-key implementation of ``cond``.

    Autograd through ``cond`` is not implemented. This implementation rejects
    operands that require grad, then calls ``cond`` again while holding an
    ``ExcludeDispatchKeyGuard`` for the AutogradCPU key set.

    Raises:
        RuntimeError: if any tensor among ``operands`` has ``requires_grad=True``.
    """
    # Only the operands can carry tensors; the branch callables are pytree leaves
    # that would never pass the isinstance check, so they need not be flattened.
    flat_operands, _ = tree_flatten(operands)
    # Explicit raise instead of assert: under `python -O` an assert would be
    # stripped and cond would silently produce outputs with no gradient history.
    if any(isinstance(t, torch.Tensor) and t.requires_grad for t in flat_operands):
        raise RuntimeError("cond() does not support autograd yet: operands must not require grad")

    # The guard must stay alive (bound to a local) until the nested call returns.
    # NOTE(review): only AutogradCPU is named here although this impl is also
    # registered for AutogradCUDA; confirm this key set also excludes the CUDA
    # autograd key.
    _guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)
|
|
|
|
|
@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    """
    ``ProxyTorchDispatchMode`` implementation of ``cond``.

    Temporarily pops the current dispatch mode and hands it to ``trace_cond``,
    which records the ``cond`` call into that mode's graph.
    """
    active_mode = _get_current_dispatch_mode()
    assert active_mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as proxy_mode:
        return trace_cond(proxy_mode, cond, pred, true_fn, false_fn, operands)
|
|
|
|
|
@cond.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
    """
    ``FakeTensorMode`` implementation of ``cond``.

    Runs both branches on ``operands``. Checks that their flattened outputs
    agree in count and in extracted tensor metadata, then returns the true
    branch's (unflattened) outputs.

    Raises:
        RuntimeError: if the branches disagree in output count or metadata.
    """
    true_outs = true_fn(*operands)
    flat_true_outs = pytree.tree_flatten(true_outs)[0]
    flat_false_outs = pytree.tree_flatten(false_fn(*operands))[0]

    if len(flat_false_outs) != len(flat_true_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for pair in zip(flat_true_outs, flat_false_outs):
        true_meta, false_meta = (_extract_tensor_metadata(t) for t in pair)
        if true_meta != false_meta:
            raise RuntimeError(
                f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}")

    return true_outs
|
|
|
|
|
|
|
@cond.py_impl(DispatchKey.PythonDispatcher)
def cond_python_dispatcher(*args):
    """
    PythonDispatcher-key implementation of ``cond``.

    Calls ``cond(*args)`` again while holding an ``ExcludeDispatchKeyGuard``
    for the PythonDispatcher key.
    """
    excluded_keys = DispatchKeySet(DispatchKey.PythonDispatcher)
    # Bound to a local so the guard stays alive until the nested call returns.
    _guard = ExcludeDispatchKeyGuard(excluded_keys)
    return cond(*args)
|
|
|
|
|
def _has_potential_branch_input_mutation(branch, fake_inputs): |
|
""" |
|
Dispatch-trace the branch with fake inputs and check if |
|
producing graph has mutable op on the input. This is |
|
bit restrictive as the branch must be traceable. |
|
""" |
|
try: |
|
gm = make_fx(branch)(*fake_inputs) |
|
except UnsupportedAliasMutationException: |
|
|
|
|
|
return True |
|
except Exception as e: |
|
raise e |
|
|
|
input_nodes = set() |
|
for node in gm.graph.nodes: |
|
if node.op == "placeholder": |
|
input_nodes.add(node) |
|
if node.op == "call_function": |
|
target = node.target |
|
if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable: |
|
for arg in node.args: |
|
if arg in input_nodes: |
|
return True |
|
|
|
return False |
|
|
|
def _has_potential_branch_input_alias(branch, fake_inputs): |
|
""" |
|
Dispatch-trace the branch with fake inputs and check if |
|
producing graph has output aliasing the branch input. This is |
|
bit restrictive as the branch must be traceable. |
|
""" |
|
try: |
|
gm = make_fx(branch)(*fake_inputs) |
|
except UnsupportedAliasMutationException: |
|
|
|
|
|
return True |
|
except Exception as e: |
|
raise e |
|
|
|
input_storages = set() |
|
for node in gm.graph.nodes: |
|
if node.op == "placeholder": |
|
input_storages.add(StorageWeakRef(node.meta['val']._typed_storage())) |
|
|
|
outs, _ = pytree.tree_flatten(gm(*fake_inputs)) |
|
for out in outs: |
|
if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages: |
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for ``cond``.

    Steps:

    1. Unwrap ``inputs`` and ``pred`` from their functional wrappers.
    2. Functionalize both branches.
    3. Under ``interpreter.lower()`` and a fresh ``FakeTensorMode``, check
       against fake copies of the inputs that:
       - the functionalized branches do not (potentially) mutate their inputs;
       - the original branches do not (potentially) return aliases of their inputs.
    4. Call ``cond`` on the unwrapped values and wrap the result back to
       ``interpreter.level()``.

    The checks are best-effort and not exhaustive (see
    ``_has_potential_branch_input_mutation`` / ``_has_potential_branch_input_alias``).

    Raises:
        UnsupportedAliasMutationException: if either check fires.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = 'mutations_and_views' if reapply_views else 'mutations'

    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        with FakeTensorMode() as ft_mode:
            def convert(x):
                return ft_mode.fake_tensor_converter(ft_mode, x)

            for branch in (functional_true_fn, functional_false_fn):
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be modifying the input!")

            # The alias check runs on the original (non-functionalized) branches.
            for branch in (true_fn, false_fn):
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond branch "
                                                            "might be aliasing the input!")

        # The real call runs outside the FakeTensorMode but still under interpreter.lower().
        cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
|
|
|
|
|
# Mark these dispatch keys as fallthroughs for `cond` (no Python
# implementation is registered for them above).
for _fallthrough_key in (
    DispatchKey.PythonTLSSnapshot,
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
):
    cond.fallthrough(_fallthrough_key)
del _fallthrough_key
|
|