import itertools
from contextlib import nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import dynamo_timed, preserve_rng_state
from torch._guards import detect_fake_mode
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions

from . import config
from .partitioners import default_partition
from ._aot_autograd.utils import (  # noqa: F401
    strict_zip,
    _get_symint_hints,
    KNOWN_TYPES,
    partial_flatten_asdict,
    normalize_as_list,
    _get_autocast_states,
    make_boxed_func,
    make_boxed_compiler,
    call_func_at_runtime_with_args,
    create_tree_flattened_fn,
    maybe_to_fresh_input,
)
from ._aot_autograd.logging_utils import (  # noqa: F401
    graph_being_compiled,
    nth_graph,
    model_name,
    set_model_name,
    get_aot_compilation_context,
    get_aot_graph_name,
    get_graph_being_compiled,
    track_graph_compiling,
    callback_set,
    setup_stacktrace_preservation_hooks,
    describe_input,
    format_guard_bug_msg,
)
from ._aot_autograd.functional_utils import (  # noqa: F401
    is_fun,
    to_fun,
    from_fun,
    sync_functional_tensor,
    has_metadata_mutation,
    has_data_mutation,
    are_all_mutations_hidden_from_autograd,
    are_all_mutations_under_no_grad_or_inference_mode,
    gen_alias_from_base,
    assert_functional_graph,
    _get_mutation_type,
    _check_if_mutation_can_be_in_graph,
)
from ._aot_autograd.schemas import (  # noqa: F401
    OutputType,
    OutputAliasInfo,
    MutationType,
    InputAliasInfo,
    SubclassCreationMeta,
    ViewAndMutationMeta,
    SubclassMeta,
    TensorAlias,
    BackwardSignature,
    GraphOutputName,
    GraphInputName,
    FQN,
    GraphSignature,
    AOTConfig,
)
from ._aot_autograd.subclass_utils import (  # noqa: F401
    requires_subclass_dispatch,
    unwrap_tensor_subclasses,
    wrap_tensor_subclasses,
    wrap_tensor_subclasses_maybe_joint,
    create_metadata_for_subclass,
)
from ._aot_autograd.collect_metadata_analysis import (  # noqa: F401
    run_functionalized_fw_and_collect_metadata,
)
from ._aot_autograd.input_output_analysis import (  # noqa: F401
    remove_dupe_metadata,
    create_synthetic_base_metadata,
    _tensors_definitely_do_not_overlap,
    compute_overlapping_inputs,
    create_graph_signature,
)
from ._aot_autograd.traced_function_transforms import (  # noqa: F401
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
    create_functionalized_fn,
    create_functionalized_rng_ops_wrapper,
    aot_dispatch_subclass,
    create_functional_call,
    create_joint,
)
from ._aot_autograd.runtime_wrappers import (  # noqa: F401
    create_runtime_wrapper,
    functionalized_rng_runtime_epilogue,
    aot_dispatch_subclass_wrapper,
    aot_wrapper_dedupe,
    aot_wrapper_synthetic_base,
    merge_view_inputs,
)
from ._aot_autograd.dispatch_and_compile_graph import (  # noqa: F401
    aot_dispatch_base_graph,
    aot_dispatch_autograd_graph,
)
from ._aot_autograd.jit_compile_runtime_wrappers import (  # noqa: F401
    aot_dispatch_base,
    aot_dispatch_autograd,
)

zip = strict_zip

# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further).
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions,
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile,
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
#     x.mul_(2)
#     out = x.mul(3)
#     return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
#     x_updated = x.mul(2)
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.copy_(x_updated)
#     return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
#
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph.
#
# Example: original user code:
# def f(x):
#     x.t_()
#     out = x.mul(3)
#     return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     x_updated = x.t()
#     out = x_updated.mul(3)
#     return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     x_updated, out = compiled_forward_graph(x)
#     return x_updated, out
#
# def compiled_wrapper(x):
#     x_updated, out = autograd.Function.apply(x)
#     x.as_strided_(x_updated)
#     return out
#
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward
#     cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
#     in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
#
# You might wonder why we return the aliased output directly in the graph (making the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias.
# There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
#     This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
#     when it has a different set of strides.
#     By including the view op directly in the graph, inductor takes that into account when deciding what memory format
#     the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
# *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph.
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
#     out1 = x.t()
#     intermediate = x.mul(2)
#     out2 = intermediate.view(-1)
#     # the compiled graph also returns the intermediate
#     return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
#     grad_x = ...
#     return grad_x
#
# def autograd.Function.forward(x):
#     out1, out2, intermediate = compiled_forward_graph(x)
#     return out1, out2, intermediate
#
# def compiled_wrapper(x):
#     out1, out2, intermediate = autograd.Function.apply(x)
#     # regenerate out1 from the input
#     out1_regenerated = out1._view_func(x)
#     # regenerate out2 from the intermediate
#     out2_regenerated = out2._view_func(intermediate)
#     return out1_regenerated, out2_regenerated
#
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
# def f(x, x_view):
#     x.mul_(2)
#     out = x * x_view
#     return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base):
#     x = generate_x(base)
#     x_view = generate_x_view(base)
#     x_updated = x.mul(2)
#     x_view_updated = x_updated.view(-1)
#     out = x_updated * x_view_updated
#     return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
#     grad_base = ...
#     return grad_base
#
# def autograd.Function.forward(base):
#     x_updated, out = compiled_forward_graph(base)
#     return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
#     base = merge_view_inputs(x, x_view)
#     x_updated, out = autograd.Function.apply(base)
#     # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
#     x.copy_(x_updated)
#     return out
#
# Note [AOT Autograd: Views to avoid tangents aliasing inputs]
#
# We view every forward output when creating the tangent tensors, to handle the problematic
# case in which a subclass does extra aliasing between graph outputs/inputs in a way that
# is not visible above the subclass.
#
# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd,
# we're guaranteed that the tangent tensors that we pass
# into the joint are distinct tensors from the primals. This is because when we
# decide which forward outputs to create tangents for, we only create tangents
# for forward outputs that are not aliases of inputs (See Note
# [AOT Autograd: outputs aliasing inputs or intermediates!]).
#
# However, when wrapper tensor subclasses enter the picture, it is possible
# to have an output of the forward that is a subclass that is not an
# input / alias of an input, but one of its inner tensors is an alias!
# NestedTensor is an example: performing an out-of-place pointwise op on a
# NestedTensor constructs a fresh NestedTensor that holds onto the input's
# offsets tensor directly.
#
# Having tangent tensors that are the same as the (primal) forward inputs
# can cause problems during tracing, as make_fx() will specialize on our
# duplicate inputs: if we passed in the same tensor for primals_1 and
# tangents_1 during tracing, make_fx() will happily sub out all usages of
# tangents_1 with primals_1 in the graph, which is not what we want.
#
# To work around this, we view every forward output when creating the tangent
# tensors so that tangents can never be the same as forward inputs, even if
# forward inputs alias forward outputs.
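#
# A minimal sketch of that hazard, for illustration only (the function and
# variable names here are hypothetical, not AOTAutograd internals):
#
#   import torch
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def joint(primal, tangent):
#       return primal.mul(2), tangent.mul(3)
#
#   x = torch.randn(3)
#   # Same tensor object passed for both roles: the tracer maps the tensor to
#   # a single proxy, so uses of `tangent` get recorded against the primal
#   # placeholder and the tangent input is effectively dead in the graph.
#   bad = make_fx(joint)(x, x)
#   # Passing a fresh view for the tangent keeps the two inputs distinct:
#   good = make_fx(joint)(x, x.view_as(x))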
#
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

aot_autograd_decompositions = {}
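# A hypothetical sketch of how an entry could be registered here; anything in
# this dict gets merged into every AOTConfig's decompositions (see
# create_aot_dispatcher_function below):
#
#   from torch._decomp import get_decompositions
#   aot_autograd_decompositions.update(
#       get_decompositions([torch.ops.aten.addmm])
#   )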


def create_aot_dispatcher_function(
    flat_fn, flat_args: List[Any], aot_config: AOTConfig
):
    """
    Traces the forward and backward graphs of :attr:`flat_fn` to generate a
    joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
    the tracing mechanism to understand the graph capturing details.

    The joint graph is then passed through :attr:`partition_fn` to isolate the
    forward and backward portions, which are then respectively compiled via the
    provided :attr:`fw_compiler` and :attr:`bw_compiler`.

    The resulting compiled forward and backward graphs are then wrapped up in a
    ``torch.autograd.Function`` object.

    The calling convention here is that the first aot_config.num_params_buffers
    inputs in flat_args are parameters and buffers, and the rest are inputs.
    We use this to assume that parameters' and buffers' shapes don't change.

    Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
    When aot_config.is_export is True, we return an FX graph + metadata
    When aot_config.is_export is False, we return an ordinary runtime function
    """
    # This is the main entry point.
    # TODO: Chillee argues that dynamo itself should pass in fake tensors to
    # the list of arguments when compiling; at the moment we do not do this
    if aot_config.decompositions is None:
        aot_config.decompositions = {}

    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **aot_config.decompositions,
    }

    if config.functionalize_rng_ops:
        # Update the decompositions with functionalized random decompositions
        aot_config.decompositions = {
            **rng_decompositions,
            **aot_config.decompositions,
        }

    # Check flat_args to see if they're already fake. If so, use that fake
    # mode instead.
    fake_mode = detect_fake_mode(flat_args)
    if fake_mode is None:
        shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
        fake_mode = FakeTensorMode(shape_env=shape_env)
    else:
        shape_env = fake_mode.shape_env

    python_dispatcher_mode = (
        enable_python_dispatcher() if shape_env is not None else nullcontext()
    )

    with torch.autograd.set_multithreading_enabled(
        False
    ), preserve_rng_state(), fake_mode, python_dispatcher_mode, PhiloxStateTracker():

        def process_inputs(flat_args):
            def convert(idx, x):
                if shape_env is not None:
                    from torch._dynamo.source import ConstantSource

                    if isinstance(x, int):
                        # We always specialize on scalar values in export.
                        if aot_config.is_export:
                            return x
                        source = ConstantSource(f"sym_{idx}")
                        return shape_env.create_symintnode(
                            shape_env.create_symbol(x, source), hint=x, source=source
                        )
                if not isinstance(x, torch.Tensor):
                    return x
                if isinstance(x, FakeTensor):
                    assert x.fake_mode is fake_mode
                    return x
                if is_traceable_wrapper_subclass(x):
                    attrs, _ = x.__tensor_flatten__()
                    if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
                        assert all(
                            getattr(x, attr).fake_mode is fake_mode for attr in attrs
                        )
                        return x

                # see note [Tensor Fakification and Symbol Caching]
                symbolic_context = None
                source = None
                if tracing_context := torch._guards.TracingContext.try_get():
                    if x in tracing_context.tensor_to_context:
                        symbolic_context = tracing_context.tensor_to_context[x]
                        source = symbolic_context.tensor_source
                if (
                    idx < aot_config.num_params_buffers
                    and config.static_weight_shapes
                    and not symbolic_context
                ):
                    # TODO: Ensure that this codepath is never exercised from
                    # Dynamo
                    return fake_mode.from_tensor(x, static_shapes=True)

                return fake_mode.from_tensor(
                    x,
                    static_shapes=False,
                    symbolic_context=symbolic_context,
                    source=source,
                )

            return [convert(idx, x) for idx, x in enumerate(flat_args)]
        fake_flat_args = process_inputs(flat_args)

        needs_autograd = (
            any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
            and torch.is_grad_enabled()
        )

        with enable_python_dispatcher():
            # Patch set_rng_state as set_rng_state with fake tensors is
            # nonsensical. This does not affect the collection of metadata.
            with patch("torch.cuda.set_rng_state", lambda *args: None):
                fw_metadata = run_functionalized_fw_and_collect_metadata(
                    flat_fn,
                    keep_input_mutations=aot_config.keep_inference_input_mutations,
                    is_train=needs_autograd,
                )(*fake_flat_args)

                req_subclass_dispatch = requires_subclass_dispatch(
                    fake_flat_args, fw_metadata
                )

                if needs_autograd and not any(
                    x.requires_grad for x in fw_metadata.output_info
                ):
                    # We realized that none of the outputs require grad,
                    # so we actually have an inference graph.
                    needs_autograd = False
                    # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta
                    # changes depending on whether we pass in is_train / keep_input_mutations,
                    # so we're forced to recompute the metadata.
                    # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata
                    # so that this is unnecessary.
                    if req_subclass_dispatch:
                        fw_metadata = run_functionalized_fw_and_collect_metadata(
                            flat_fn,
                            keep_input_mutations=aot_config.keep_inference_input_mutations
                            and not needs_autograd,
                            is_train=needs_autograd,
                        )(*fake_flat_args)
                    else:
                        fw_metadata = ViewAndMutationMeta(
                            input_info=fw_metadata.input_info,
                            output_info=fw_metadata.output_info,
                            num_intermediate_bases=fw_metadata.num_intermediate_bases,
                            keep_input_mutations=aot_config.keep_inference_input_mutations
                            and not needs_autograd,
                            traced_tangents=fw_metadata.traced_tangents,
                            subclass_inp_meta=fw_metadata.subclass_inp_meta,
                            subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
                            subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                            is_train=needs_autograd,
                        )

        if fw_metadata.num_intermediate_bases > 0:
            assert not req_subclass_dispatch, f"""\
torch.compile is currently being used with tensor subclass inputs:
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to compile a graph with two graph outputs
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
please file a github issue"""

        if aot_config.is_export:
            # aot_export: ban input metadata mutations for now to keep shared code paths simpler.
            # Keeping .resize_() in the graph will require some work.
            # Allowing it while keeping the graph functional will require some calling convention changes.
            if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
                raise RuntimeError(f"""\
Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}""")
            # In export, ban data mutations on inputs that require grad for now.
            # This should be rare, and is tricky to get right. When we trace the backward,
            # we currently trace with autograd.grad instead of .backward(), which makes it difficult
            # to ensure that we run autograd all the way through the input **before** it saw the mutation.
            if (
                len(
                    [
                        x
                        for x in fw_metadata.input_info
                        if x.requires_grad and x.mutates_data
                    ]
                )
                != 0
            ):
                raise RuntimeError(f"""\
Found a graph input that requires gradients, and received a mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}""")
            if req_subclass_dispatch:
                raise RuntimeError("""\
aot_export is not currently supported with traceable tensor subclasses.
If you need this feature, please comment on <CREATE_ISSUE_LINK>""")

            # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad,
            # and turning it on will require a non-trivial calling convention change for any export runtime.
            if config.functionalize_rng_ops:
                raise RuntimeError("""\
Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
or otherwise set torch._functorch.config.functionalize_rng_ops = False.""")

        # crappy version of dispatcher
        # TODO: Do this properly
        if needs_autograd:
            # For now, aot_dispatch_autograd knows to explicitly return a graph
            # when run with export, and an opaque callable otherwise.
            # In theory we could factor these out, but I wanted to let the dust
            # settle on how functionalized rng fits into export first.
            compiler_fn = (
                aot_dispatch_autograd_graph
                if aot_config.is_export
                else aot_dispatch_autograd
            )
        else:
            # aot_dispatch_base_graph contains only the "graph bits", while
            # aot_dispatch_base includes some extra work around handling a
            # runtime epilogue.
            compiler_fn = (
                aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base
            )

        compiler_fn = partial(
            aot_wrapper_synthetic_base,
            compiler_fn=compiler_fn,
            needs_autograd=needs_autograd,
        )
        compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
        # You can put more passes here
        compiled_fn = compiler_fn(
            flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
        )

        if aot_config.is_export:
            mutated_user_inp_locs = [
                idx - aot_config.num_params_buffers
                for idx in fw_metadata.mutated_inp_runtime_indices
                if idx >= aot_config.num_params_buffers
            ]
            if len(mutated_user_inp_locs) > 0:
                raise RuntimeError(f"""
Found the following user inputs located at {mutated_user_inp_locs} which were mutated.
This is currently banned in the aot_export workflow.
If you need this functionality, please file a github issue.

fw_metadata={str(fw_metadata)}""")
            # During export, we don't get back a callable - we get back the raw fx graph
            # (either a joint or an inference-only graph)
            assert isinstance(compiled_fn, torch.fx.GraphModule)
            return compiled_fn, fw_metadata

        if not hasattr(compiled_fn, "_boxed_call"):
            compiled_fn = make_boxed_func(compiled_fn)

        return compiled_fn


def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    num_params_buffers: int = 0,
    keep_inference_input_mutations: bool = False,
    inference_compiler: Optional[Callable] = None,
    *,
    # Whether or not to trace with dynamic shapes
    dynamic=False,
    enable_log=True,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph. :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One can
    set the `decompositions` dictionary to decompose the operators into a sequence
    of core or simpler operators supported by the backend compilers.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Dict): A dictionary to define the decomposition of
            larger Aten ops into simpler or core Aten ops.
        inference_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. inference_compiler is invoked
            if no autograd is needed. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)

    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

    >>> fn = lambda x : x.sin().cos()
    >>> def print_compile_fn(fx_module, args):
    >>>     print(fx_module)
    >>>     return fx_module
    >>> aot_fn = aot_function(fn, print_compile_fn)
    >>> x = torch.randn(4, 5, requires_grad=True)
    >>> aot_fn(x)
    """
    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic,
        aot_autograd_arg_pos_to_source=None,
        is_export=False,
        no_tangents=False,
        enable_log=enable_log,
    )
    cached_res = None

    def returned_function(*args, **kwargs):
        nonlocal cached_res
        # Now flatten the tensor args
        flat_args = pytree.arg_tree_leaves(*args, **kwargs)

        # Compile the function and save it in the cache
        if cached_res is None:
            flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
            compiled_fn = create_aot_dispatcher_function(
                flat_fn,
                flat_args,
                aot_config,
            )
            cached_res = (compiled_fn, out_spec)

        cached_fn, out_spec = cached_res
        out = cached_fn(flat_args)
        return out_spec.unflatten(out)

    return returned_function


def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is a wrapper function that uses :func:`aot_function`
    underneath to perform tracing and compilation.

    :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
    to a new callable which is then compiled through :func:`aot_function`.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module``.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.
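
    A sketch of typical usage, reusing the ``print_compile_fn`` pattern from
    :func:`aot_function` above (illustrative only):

    >>> mod = nn.Linear(4, 4)
    >>> def print_compile_fn(fx_module, args):
    >>>     print(fx_module)
    >>>     return fx_module
    >>> aot_mod = aot_module(mod, print_compile_fn)
    >>> out = aot_mod(torch.randn(2, 4))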
""" | |
# See Note: [Fake Modules and AOTAutograd] | |
torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) | |
def functional_call(named_params, named_buffers, *args, **kwargs): | |
params_and_buffers = {**named_params, **named_buffers} | |
return torch.func.functional_call(mod, params_and_buffers, args, kwargs) | |
named_params = dict(mod.named_parameters(remove_duplicate=False)) | |
named_buffers = dict(mod.named_buffers(remove_duplicate=False)) | |
num_params_buffers = len(named_params) + len(named_buffers) | |
compiled_f = aot_function( | |
functional_call, *args, num_params_buffers=num_params_buffers, **kwargs | |
) | |
class AOTModule(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.orig_module = mod | |
def forward(self, *args, **kwargs): | |
return compiled_f( | |
named_params, | |
named_buffers, | |
*args, | |
**kwargs, | |
) | |
return AOTModule() | |


def aot_module_simplified(
    mod: nn.Module,
    args,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    keep_inference_input_mutations=False,
    inference_compiler: Optional[Callable] = None,
) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove:
    (1) the pytree overhead to parse inputs/outputs,
    (2) the AOT Autograd cache,
    (3) the reading of params/buffers in every forward call

    :func:`aot_module_simplified` removes these overheads.
    """
    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    functional_call = create_functional_call(mod, params_spec, params_len)

    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
        inference_compiler = fw_compiler
    seen_sources = set()

    full_args = []
    # First, the params
    full_args.extend(params_flat)

    if tracing_context := torch._guards.TracingContext.try_get():
        tracing_context.params_flat = params_flat

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.
    if hasattr(mod, "_param_name_to_source"):
        aot_autograd_arg_pos_to_source = []
        # We now know this came from dynamo, and (1) we care about guards,
        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
        for name in params.keys():
            assert name in mod._param_name_to_source, f"{name} not found."
            source = mod._param_name_to_source[name]
            assert source not in seen_sources, source
            seen_sources.add(source)
            aot_autograd_arg_pos_to_source.append(source)

    # Next, the input args
    full_args.extend(args)

    if hasattr(mod, "graph"):
        # Non dynamo entrypoints can get to here...
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                if hasattr(node, "_dynamo_source"):
                    # ... but not here!
                    if aot_autograd_arg_pos_to_source is None:
                        aot_autograd_arg_pos_to_source = []
                    source = node._dynamo_source
                    assert source not in seen_sources, source
                    seen_sources.add(source)
                    aot_autograd_arg_pos_to_source.append(source)

    if aot_autograd_arg_pos_to_source is not None:
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)

    dynamic_shapes = False
    for x in full_args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        inference_compiler=inference_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
        is_export=False,
        no_tangents=False,
    )

    with compiled_autograd.disable():
        compiled_fn = create_aot_dispatcher_function(
            functional_call,
            full_args,
            aot_config,
        )

    # TODO: There is something deeply wrong here; compiled_fn runs with
    # the boxed calling convention, but aot_module_simplified somehow
    # historically returned a function that was not the boxed calling
    # convention. This should get fixed...
    def forward(*runtime_args):
        full_args = []
        full_args.extend(params_flat)
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward


def aot_export_module(
    mod: nn.Module,
    args,
    *,
    decompositions: Optional[Dict] = None,
    # If true, we'll return a joint forward-backward graph,
    # as well as metadata on the loss + gradients in the backward.
    trace_joint: bool,
    # If trace_joint is True, we expect your module to return a scalar loss.
    # Your module can return multiple outputs, so you must specify which output the loss is.
    output_loss_index: Optional[int] = None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
""" | |
This function takes in a module, and returns: | |
(1) an FX graph that can be exported | |
(2) some metadata about the graph | |
If `trace_joint=True` we will return a joint graph of the forward + backward. | |
The traced FX graph will have the following properties compared to the original module: | |
(1) Inputs and outputs to the module will be pytree-flattened | |
(2) Parameters and buffers on the module will be lifted into graph inputs, | |
graph_inputs = (*parameters, *buffers, *user_inputs) | |
(3) The graph will be fully functionalized | |
(4) Any input mutations will be converted into additional outputs in the graph, | |
meaning whoever calls this graph is responsible for applying the mutations | |
back to the original inputs. | |
(5) If is_joint is provided the graph will return parameter gradients in addition to user outputs. | |
The graph output will look like: | |
graph_outputs = (*updated_inputs, *user_outputs, *param_gradients) | |
There are also several restrictions on what modules can use this API. In particular: | |
(1) If trace_joint is specified, we expect the loss function to be **fused** | |
into the module forward. One of the outputs to the forward must be a scalar loss, | |
which is specified with `output_loss_index`. | |
All other outputs to the forward are presumed to not require gradients. | |
(2) This API cannot capture optimizers (although in theory we could build an API for this). | |
(3) Metadata mutations on params/buffers/inputs are banned. | |
(4) Data mutations on anything that requires gradients are banned (parameters) | |
(5) If an input is mutated, it is not allowed to alias any other inputs. | |
(6) Parameters must not be duplicated. | |
""" | |
    """
    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
    named_buffers = dict(mod.named_buffers(remove_duplicate=False))

    params_and_buffers = {
        **dict(named_parameters),
        **dict(named_buffers),
    }
    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
    params_and_buffers_flat = tuple(params_and_buffers_flat)
    params_len = len(params_and_buffers_flat)

    functional_call = create_functional_call(mod, params_spec, params_len)
    num_fw_outs = None

    if trace_joint:
        # This helper effectively just adds some extra asserts about what the backward will look like:
        # Outputs must include a scalar loss, which we will compute gradients w.r.t.
        # We don't compute gradients w.r.t. anything else, so we require the user
        # to detach() every other output (the errors below explain how).
        def fn_to_trace(*args):
            nonlocal num_fw_outs
            out = functional_call(*args)
            if output_loss_index is None:
                raise RuntimeError("""\
If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
You must specify which output is the loss (by index) with output_loss_index.""")
            if isinstance(out, torch.Tensor):
                out = (out,)
            if not isinstance(out, (tuple, list)):
                raise RuntimeError(
                    f"Expected forward output to be either a tensor or a list/tuple of tensors. Found {type(out)}"
                )

            for i, o in enumerate(out):
                # We only want to create a backward graph w.r.t. the loss that the user passed in.
                # This implies that every other output should not require gradients.
                # Rather than silently detaching other outputs here, we make this an
                # error and ask the user to detach them in their forward.
                if o.requires_grad and i != output_loss_index:
                    raise RuntimeError(f"""\
Found an output of the forward that requires gradients, that was not the scalar loss.
We require all outputs to the forward that are not the scalar loss to not require gradient,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients.""")
            out_loss = out[output_loss_index]
            num_fw_outs = len(out)
            if not out_loss.requires_grad:
                raise RuntimeError(f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients""")
            if out_loss.numel() != 1:
                raise RuntimeError(f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}""")
            return out

        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        ctx = torch.no_grad
        fn_to_trace = functional_call
    full_args = []
    # First, the params
    # NB: It is REQUIRED that parameters come first. Inductor infers "fixed"
    # parameters by looking at the difference in parameter count outside
    # and inside AOTAutograd, and assumes the prefix of arguments are fixed
    # arguments
    full_args.extend(params_and_buffers_flat)
    # Next, the input args
    full_args.extend(args)

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            fn_to_trace,
            full_args,
            decompositions=decompositions,
            num_params_buffers=params_len,
            no_tangents=True,
        )
    if trace_joint:

        def flattened_joint(*args):
            # The idea here is that the joint graph that AOTAutograd creates has some strict properties:
            # (1) It accepts two arguments (primals, tangents), and pytree_flattens them
            # (2) It returns a tuple of (fw_outs, gradients)
            # This is a very useful convention for anyone who wants to partition the joint graph
            # into a separate forward and backward graph.
            # However:
            # (1) for people exporting a single joint graph, it would be preferable not to have
            #     any pytrees in the graph.
            # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss,
            #     and there are therefore no tangents that are needed to run the joint graph.
            # (3) AOTAutograd creates a grad_input for every input in the forward,
            #     including None's for inputs that are not grad-requiring tensors.
            #     We don't want these in our export graph.
            # This function "fixes" the above by removing any tangent inputs,
            # and removing pytrees from the original FX graph.
            fake_tangents = [
                None
                for _ in range(
                    metadata.num_outputs + metadata.num_mutated_inp_runtime_indices
                )
            ]
            fw_outs, gradients = fx_g(args, fake_tangents)
            assert len(gradients) == len(args)
            output_gradients = []
            for a, grad in zip(args, gradients):
                if isinstance(a, torch.Tensor) and a.requires_grad:
                    assert grad is not None, """\
Found a parameter that did not receive a gradient.
This is most likely a bug, but if this needs to be supported please comment on this Github issue:
https://github.com/pytorch/pytorch/issues/101192
"""
                    output_gradients.append(grad)
                else:
                    assert grad is None
            return *fw_outs, *output_gradients

        fx_g = make_fx(flattened_joint)(*full_args)
    user_args_flat = pytree.arg_tree_leaves(*args)
    return fx_g, create_graph_signature(
        fx_g,
        metadata,
        in_spec,
        out_spec,
        user_args_flat=user_args_flat,
        params_and_buffers_flat=params_and_buffers_flat,
        param_names=list(named_parameters.keys()),
        buffer_names=list(named_buffers.keys()),
        trace_joint=trace_joint,
        num_user_fw_outs=num_fw_outs,
        loss_index=output_loss_index,
    )


def aot_export_joint_simple(
    func: Callable,
    args,
    *,
    trace_joint: bool,
    # It looks like the main consequence of this API is that for dynamic shapes,
    # it will assume that params/buffers are static.
    # With the new inferred dynamic shapes API, maybe this doesn't matter?
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
) -> torch.fx.GraphModule:
""" | |
A simplified version of export. Used by higher order operators. | |
This function makes a high-level "no calling convention changes" guarantee: | |
- If no inputs require grad (so we export an inference graph), | |
there are *no* calling convention change between the exported graph, and "func". | |
- If at least one input requires grad (so we trace out and export a joint fw-bw graph), | |
Then if you were partition the graph into a separate forward and backward graph, | |
The forward graph will have no calling convention changes compared to "func". | |
The above also relies on some strong restrictions around which functions this API accepts: | |
(1) `args` cannot contain any pytrees (they must have been pytree_flattened already) | |
(2) `func` cannot mutate any inputs | |
(3) The outputs of `func` cannot alias any inputs. | |
Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops. | |
""" | |
    """
    if trace_joint:
        ctx = nullcontext
    else:
        # Run under no_grad, so our tracing machinery only traces an inference graph.
        ctx = torch.no_grad

    with ctx():
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
            func,
            args,
            decompositions=decompositions,
        )

    # At this point, we can just directly return the (joint or inference) graph that we traced.
    # First though: a bunch of assertions to make sure that our graph doesn't require
    # any calling convention changes compared to the original function.
    # These restrictions are *in addition to* the general restrictions on export.

    # No input mutations
    if (
        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
        )
    # No output aliasing
    if (
        len([x for x in metadata.output_info if x.output_type != OutputType.non_alias])
        != 0
    ):
        raise RuntimeError(
            f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}"
        )
    # No pytrees
    if type(in_spec) == pytree.LeafSpec:
        raise RuntimeError(
            f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
        )
    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
        raise RuntimeError(
            f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
        )
    if type(out_spec) == pytree.LeafSpec:
        raise RuntimeError(
            f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
        )
    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
        raise RuntimeError(
            f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
        )
    # TODO: we might have to temporarily patch config.functionalize_rng
    # so that it doesn't run when we're exporting a higher order op.

    if config.debug_assert:
        # Smoke test that after partitioning, we can run the forward without any calling convention changes.
        fw_module, bw_module = default_partition(
            fx_g, args, num_fwd_outputs=len(metadata.output_info)
        )
        # Attempt to run the fw_module with the original user inputs
        fake_mode = detect_fake_mode(args)
        if fake_mode is None:
            fake_mode = FakeTensorMode()
        with fake_mode:
            fw_module(*args)
    return fx_g


# Private for now because we aren't providing a contract on what to return
# for joint graphs (we could when there's a clearer use case).
# In the future, we may need to add more export API's that provide their own strong guarantees.
# This is meant as a general helper function for handling various export-y use cases.
def _aot_export_function(
    func: Callable,
    args,
    *,
    num_params_buffers: int = 0,
    decompositions: Optional[Dict] = None,
    # If we're exporting a joint graph and we don't want any tangent inputs in the graph
    # (because we are backpropping through a scalar 1 loss),
    # we need to explicitly specify not to include tangents in the graph.
    # It's not enough just to check that our tangent is a scalar, since we also
    # need to know if it is a 1 (no need to make it a graph input), or something else
    # (requiring it to be a graph input).
    # We don't know this info at trace time though, so we need to make it an explicit config.
    no_tangents: bool = False,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
    dynamic_shapes = False
    for x in args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break

    flat_fn, out_spec = create_tree_flattened_fn(func, args)
    flat_args, in_spec = pytree.tree_flatten(args)

    # The export use case doesn't care about several bits of AOTConfig
    # (1) compilers (we just export the graph)
    # (2) partitioners (export is only full graph, user can partition themselves)
    aot_config = AOTConfig(
        fw_compiler=None,
        bw_compiler=None,
        inference_compiler=None,
        partition_fn=None,
        decompositions=decompositions,
        num_params_buffers=num_params_buffers,
        aot_id=next(AOT_COUNTER),
        # For now there's no use case involving keeping input mutations in the graph
        # (which we can only do in the inference case anyway).
        # We can add this later if we need to.
        keep_inference_input_mutations=False,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=None,
        is_export=True,
        no_tangents=no_tangents,
    )

    fx_g, meta = create_aot_dispatcher_function(
        flat_fn,
        flat_args,
        aot_config,
    )
    return fx_g, meta, in_spec, out_spec.spec


compiled_function = aot_function
compiled_module = aot_module