Spaces:
Runtime error
Runtime error
import copy | |
import dataclasses | |
import functools | |
import io | |
import json | |
import pathlib | |
import re | |
import sys | |
import types | |
import warnings | |
import weakref | |
import zipfile | |
from collections import OrderedDict | |
from contextlib import contextmanager | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
from unittest.mock import patch | |
import sympy | |
import torch | |
import torch._dynamo | |
import torch.fx | |
import torch.fx._pytree as fx_pytree | |
import torch.utils._pytree as pytree | |
from torch._decomp import core_aten_decompositions, get_decompositions | |
from torch._dispatch.python import enable_python_dispatcher | |
from torch._dynamo.exc import UserError, UserErrorType | |
from torch._dynamo.source import ConstantSource | |
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass | |
from torch._functorch.aot_autograd import aot_export_module, GraphSignature | |
from torch._functorch.eager_transforms import functionalize | |
from torch._guards import detect_fake_mode | |
from torch._ops import OpOverload | |
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode | |
from torch.export import _create_constraint, _Dim, Constraint | |
from torch.export.exported_program import ( | |
ExportedProgram, | |
ModuleCallEntry, | |
ModuleCallSignature, | |
_disable_prexisiting_fake_mode, | |
) | |
from torch.export.graph_signature import ( | |
_sig_to_specs, | |
ArgumentSpec, | |
ConstantArgument, | |
ExportGraphSignature, | |
InputKind, | |
InputSpec, | |
OutputKind, | |
OutputSpec, | |
SymIntArgument, | |
TensorArgument, | |
) | |
from torch.fx import traceback as fx_traceback | |
from torch.fx._compatibility import compatibility | |
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode | |
from torch.fx.experimental.symbolic_shapes import ( | |
ConstraintViolationError, | |
GuardOnDataDependentSymNode, | |
ShapeEnv, | |
StrictMinMaxConstraint, | |
) | |
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo | |
from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges | |
from .exported_program import ( | |
_create_stateful_graph_module, | |
_process_constraints, | |
CallSpec, | |
) | |
from .passes.add_runtime_assertions_for_constraints_pass import ( | |
_AddRuntimeAssertionsForInlineConstraintsPass, | |
) | |
from .passes.lift_constant_tensor_pass import lift_constant_tensor_pass | |
from .passes.remove_runtime_assertions import _RemoveRuntimeAssertionsPass | |
from .passes.replace_sym_size_ops_pass import _replace_sym_size_ops_pass | |
from .passes.replace_view_ops_with_view_copy_ops_pass import ( | |
ReplaceViewOpsWithViewCopyOpsPass, | |
) | |
from .wrappers import _wrap_submodules | |
def _process_dynamic_shapes( | |
f: Callable, | |
args: Tuple[Any, ...], | |
kwargs: Optional[Dict[str, Any]] = None, | |
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, | |
) -> Optional[List[Constraint]]: | |
if dynamic_shapes is None or len(dynamic_shapes) == 0: | |
return None | |
kwargs = kwargs if kwargs is not None else {} | |
from collections.abc import Mapping, Sequence | |
def tree_zip(combined_args, dynamic_shapes): | |
if isinstance(combined_args, (tuple, list)): | |
if not isinstance(dynamic_shapes, Sequence): | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, " | |
f"got {dynamic_shapes} instead", | |
) | |
if len(combined_args) != len(dynamic_shapes): | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected {dynamic_shapes} to have {len(combined_args)} items", | |
) | |
for i, shape in enumerate(dynamic_shapes): | |
yield from tree_zip(combined_args[i], shape) | |
elif isinstance(combined_args, dict): | |
if not isinstance(dynamic_shapes, Mapping): | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, " | |
f"got {dynamic_shapes} instead", | |
) | |
if len(combined_args) != len(dynamic_shapes): | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected {dynamic_shapes} to have {len(combined_args)} items", | |
) | |
for k, shape in dynamic_shapes.items(): | |
yield from tree_zip(combined_args[k], shape) | |
elif dataclasses.is_dataclass(combined_args): | |
if not type(dynamic_shapes) == type(combined_args): | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected dynamic_shapes of a {type(combined_args)} to be a {type(combined_args)}, " | |
f"got {dynamic_shapes} instead", | |
) | |
for f in dataclasses.fields(combined_args): | |
yield from tree_zip(getattr(combined_args, f.name), getattr(dynamic_shapes, f.name)) | |
elif isinstance(combined_args, torch.Tensor): | |
yield (combined_args, dynamic_shapes) | |
else: | |
if dynamic_shapes is not None: | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Expected dynamic_shapes of a {type(combined_args)} to be None, " | |
f"got {dynamic_shapes} instead", | |
) | |
def to_constraint(dim, tensor, i): | |
constraint = dynamic_dim(tensor, i, debug_name=dim.__name__) | |
if dim.min != 2: | |
constraint = constraint >= dim.min | |
if dim.max != sys.maxsize - 1: | |
constraint = constraint <= dim.max | |
return constraint | |
from collections import defaultdict | |
symbols = defaultdict(list) | |
bounds: Dict[str, Tuple[int, int]] = {} | |
def check_same_bounds(dim): | |
if dim.__name__ in symbols: | |
min_, max_ = bounds[dim.__name__] | |
if dim.min != min_ or dim.max != max_: | |
this_ = _Dim.readable(dim.__name__, min_, max_) | |
that_ = _Dim.readable(dim.__name__, dim.min, dim.max) | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Found different definitions {this_} and {that_} " | |
f"for the same symbolic dimension {dim}!" | |
) | |
else: | |
bounds[dim.__name__] = (dim.min, dim.max) | |
def update_symbols(tensor, shape): | |
if isinstance(shape, dict): | |
for i, dim in shape.items(): | |
if isinstance(dim, _Dim): | |
check_same_bounds(dim) | |
symbols[dim.__name__].append(to_constraint(dim, tensor, i)) | |
else: | |
if dim is not None: | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, " | |
"try None instead", | |
) | |
elif isinstance(shape, (tuple, list)): | |
for i, dim in enumerate(shape): | |
if isinstance(dim, _Dim): | |
check_same_bounds(dim) | |
symbols[dim.__name__].append(to_constraint(dim, tensor, i)) | |
else: | |
if dim is not None: | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, " | |
"try None instead", | |
) | |
else: | |
if shape is not None: | |
raise UserError( | |
UserErrorType.INVALID_INPUT, | |
f"Unexpected dynamic_shape {shape} of Tensor, " | |
"try None instead", | |
) | |
import inspect | |
if isinstance(f, ExportedProgram): | |
f = f.module() | |
signature = inspect.signature(f.forward) if isinstance(f, torch.nn.Module) else inspect.signature(f) | |
combined_args = signature.bind(*args, **kwargs).arguments | |
# This means user didn't specify dynamic shapes with argument names. | |
combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values()) # type: ignore[assignment] | |
for tensor, shape in tree_zip(combined_args, dynamic_shapes): | |
update_symbols(tensor, shape) | |
constraints = [] | |
for dynamic_dims in symbols.values(): | |
primary, *others = dynamic_dims | |
if others: | |
for other in others: | |
constraints.append(primary == other) | |
else: | |
constraints.append(primary) | |
return constraints | |
def export__RC__(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    Export `f` using dynamic shape specifications instead of constraints.

    This API is a "release candidate" (RC) meant to replace `export`.

    `dynamic_shapes` is either a dict from the argument names of `f` to dynamic
    shape specifications, or a tuple whose elements follow the order of the
    arguments in the function signature. Specifications are written as follows:

    - The dynamic shape of a tensor argument is either
        - a dict from dynamic dimension indices to Dim types (static dimension
          indices may be omitted; when present they must map to None), or
        - a tuple of Dim types or None, where a Dim marks a dynamic dimension
          and None marks a static one.
    - Arguments that are dicts or tuples of tensors are specified recursively
      with mappings or sequences of the contained specifications.

    See `export` for documentation of `f`, `args`, `kwargs` and the return value.
    """
    # The specification is lowered to constraints first; the actual export is
    # delegated to `_export`.
    return _export(
        f,
        args,
        kwargs,
        constraints=_process_dynamic_shapes(f, args, kwargs, dynamic_shapes),
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
    )
def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): | |
if not isinstance(t, torch.Tensor): | |
raise UserError( | |
UserErrorType.DYNAMIC_DIM, | |
f"Expected tensor as input to dynamic_dim but got {type(t)}" | |
) | |
if t.dim() < 1: | |
raise UserError( | |
UserErrorType.DYNAMIC_DIM, | |
"Cannot mark 0-dimension tensors to be dynamic" | |
) | |
if index >= t.dim(): | |
raise UserError( | |
UserErrorType.DYNAMIC_DIM, | |
f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]" | |
f" but got {index}, which is out of bounds for the given tensor." | |
) | |
return _create_constraint( | |
weakref.ref(t), | |
id(t), | |
index, | |
StrictMinMaxConstraint( | |
vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False | |
), | |
debug_name=debug_name, | |
) | |
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.

    This must be a dataclass: the export entry points in this module convert
    an instance with `dataclasses.asdict()` before handing the result to
    `torch._dynamo.config.patch()`, and `asdict()` raises TypeError for a
    plain class.
    """
    # Emitted as the `allow_rnn` entry of the dict produced by dataclasses.asdict().
    allow_rnn: bool = True
# Shared default Dynamo settings for export; the tracing entry points in this
# module convert it with dataclasses.asdict() and apply it through
# torch._dynamo.config.patch().
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
# Module-level decomposition table, initialised from core_aten_decompositions();
# _disable_decomp_table() temporarily replaces it with an empty dict.
DECOMP_TABLE = core_aten_decompositions()
# TODO(zhxchen17) This is not needed if we output pre_dispatch graph upfront from export(). | |
def _disable_decomp_table(): | |
global DECOMP_TABLE | |
prev, DECOMP_TABLE = DECOMP_TABLE, {} | |
try: | |
yield | |
finally: | |
DECOMP_TABLE = prev | |
def capture_pre_autograd_graph(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
) -> torch.nn.Module:
    """
    Trace a module before any pre-autograd decomposition is run.

    The produced module is "non-functional" and composed of aten operators.
    This helper is expected to be removed later in favor of the more general
    torch.export API.

    Args:
        f: A callable to be traced
        args: example positional inputs.
        kwargs: optional example keyword inputs.
        constraints: An optional list of constraints on the dynamic arguments
            specifying the possible range of their shapes

    Returns:
        An nn.Module containing the traced method. Its `train()` and `eval()`
        are replaced with stubs that raise NotImplementedError.
    """
    # Each of these ops is mapped to its own `decompose` in the table handed to
    # torch._dynamo.export.
    self_decomposed_ops = (
        torch.ops.aten.dropout.default,
        torch.ops.aten.batch_norm.default,
        torch.ops.aten._batch_norm_impl_index.default,
        torch.ops.aten.native_batch_norm.default,
    )
    decomp_table = {op: op.decompose for op in self_decomposed_ops}

    def _train(self, mode: bool = True):
        raise NotImplementedError("Calling train() is not supported yet.")

    def _eval(self, mode: bool = True):
        raise NotImplementedError("Calling eval() is not supported yet.")

    kwargs = {} if kwargs is None else kwargs

    # NOTE(review): SOURCE lost its indentation; everything through the return
    # is kept inside the patched-config scope -- confirm against upstream.
    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
        tracer = torch._dynamo.export(
            f,
            constraints=constraints,
            assume_static_by_default=True,
            tracing_mode="symbolic",
            decomposition_table=decomp_table,
            pre_dispatch=True,
            aten_graph=True,
        )
        m = tracer(*args, **kwargs)[0]

        _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)

        # Keep only the runtime ranges whose symbol name is `i<digits>` or `f<digits>`.
        symbol_pattern = r"^[if]\d+$"
        m.meta["inline_constraints"] = {
            symbol: value_range
            for symbol, value_range in fake_mode.shape_env.runtime_var_to_range.items()
            if re.match(symbol_pattern, str(symbol))
        }

        flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
        range_constraints, equality_constraints = _process_constraints(m, 0, flat_args)
        unlifted_m = _create_stateful_graph_module(
            m,
            range_constraints=range_constraints,
            equality_constraints=equality_constraints,
        )
        # Switching modes on the captured module is not supported; both stubs raise.
        unlifted_m.train = types.MethodType(_train, m)  # type: ignore[method-assign]
        unlifted_m.eval = types.MethodType(_eval, m)  # type: ignore[method-assign]
        return unlifted_m
def _convert_input_to_fake(gm, args, kwargs): | |
if len(args) == 0 and len(kwargs) == 0 and len(dict(gm.named_parameters())) == 0 and len(dict(gm.named_buffers())) == 0: | |
return [], {}, {}, None | |
fake_inps: List[torch.Tensor] = [] | |
fake_mode = None | |
for node in gm.graph.nodes: | |
if node.op == "placeholder" and "val" in node.meta: | |
fake_val = node.meta["val"] | |
if fake_val is not None and isinstance(fake_val, torch.Tensor): | |
fake_inps.append(fake_val) | |
if detected_fake_mode := detect_fake_mode(fake_inps): | |
fake_mode = detected_fake_mode | |
assert fake_mode is not None, "Cannot find fake_mode attatched to the graph's placeholders." | |
count = 0 | |
def convert_to_fake(x): | |
nonlocal count | |
val = fake_inps[count] | |
count += 1 | |
return val | |
fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args) | |
# TODO properly use the cached fake tensor | |
fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs) | |
fake_params_buffers = pytree.tree_map_only(torch.Tensor, | |
functools.partial(fake_mode.from_tensor, static_shapes=True), | |
{**dict(gm.named_parameters(remove_duplicate=False)), | |
**dict(gm.named_buffers(remove_duplicate=False))}) | |
return fake_args, fake_kwargs, fake_params_buffers, fake_mode | |
def _replace_param_buffer_names(param_buffer_table, sig): | |
for spec in sig.input_specs: | |
spec.target = param_buffer_table.get(spec.target, spec.target) | |
for spec in sig.output_specs: | |
spec.target = param_buffer_table.get(spec.target, spec.target) | |
def _normalize_nn_module_stack(gm_torch_level, root_cls): | |
# Append a root module to every nn_module_stack. | |
root = "L['self']" | |
root_key = re.sub(r'[^a-zA-Z0-9]', '_', root) | |
for gm in gm_torch_level.modules(): | |
if not isinstance(gm, torch.fx.GraphModule): | |
continue | |
for node in gm.graph.nodes: | |
if node.op in ["placeholder", "output"]: | |
continue | |
add_root = True | |
if nn_module_stack := node.meta.get("nn_module_stack", {}): | |
path, ty = next(iter(nn_module_stack.values())) | |
assert issubclass(ty, torch.nn.Module) | |
# TODO Figure out why sometimes we have root sometimes we don't. | |
if path == root and ty is root_cls: | |
add_root = False | |
if add_root: | |
def normalize_path(path): | |
try: | |
parts = [] | |
class Path: | |
def __getattr__(self, name): | |
parts.append(name) | |
return self | |
def __getitem__(self, idx): | |
parts.append(str(idx)) | |
return self | |
eval(path, {"L": {"self": Path()}}) | |
return ".".join(parts) | |
except Exception: # TODO(zhxchen17) Remove this. | |
return path | |
nn_module_stack = {root_key: (root, root_cls), **nn_module_stack} | |
node.meta["nn_module_stack"] = { | |
key: (normalize_path(path), ty) | |
for key, (path, ty) in nn_module_stack.items() | |
} | |
def _export_to_torch_ir( | |
f: Callable, | |
args: Tuple[Any, ...], | |
kwargs: Optional[Dict[str, Any]] = None, | |
constraints: Optional[List[Constraint]] = None, | |
*, | |
preserve_module_call_signature: Tuple[str, ...] = (), | |
disable_constraint_solver: bool = False, | |
) -> torch.fx.GraphModule: | |
""" | |
Traces either an nn.Module's forward function or just a callable with PyTorch | |
operations inside and produce a torch.fx.GraphModule in torch IR. | |
""" | |
constraints = constraints or [] | |
kwargs = kwargs or {} | |
if not isinstance(args, tuple): | |
raise UserError(UserErrorType.INVALID_INPUT, | |
f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}") | |
# We convert to nn.Module because __call__ of ExportedProgram | |
# is untracable right now. | |
if isinstance(f, ExportedProgram): | |
f = f.module() | |
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): | |
try: | |
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} | |
with _wrap_submodules(f, preserve_module_call_signature, module_call_specs): | |
gm_torch_level, _ = torch._dynamo.export( | |
f, | |
constraints=constraints, | |
assume_static_by_default=True, | |
tracing_mode="symbolic", | |
disable_constraint_solver=disable_constraint_solver, | |
)( | |
*args, | |
**kwargs, | |
) | |
except (ConstraintViolationError, ValueRangeError) as e: | |
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 | |
except GuardOnDataDependentSymNode as e: | |
raise UserError( # noqa: TRY200 | |
UserErrorType.ANTI_PATTERN, | |
f"Consider annotating your code using torch._constrain_as_*(). {str(e)}", | |
case_name="constrain_as_size_example", | |
) | |
gm_torch_level.meta["module_call_specs"] = module_call_specs | |
return gm_torch_level | |
def export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """
    Export `f` by delegating to `_export` with the same arguments.

    Passing `constraints` (anything other than None) is deprecated and emits
    a DeprecationWarning attributed to the caller; `dynamic_shapes` is the
    suggested replacement.
    """
    if constraints is not None:
        deprecation_notice = (
            "Using `constraints` to specify dynamic shapes for export is DEPRECATED "
            "and will not be supported in the future. "
            "Please use `dynamic_shapes` instead (see docs on `torch.export.export`)."
        )
        # stacklevel=2 points the warning at the caller of export().
        warnings.warn(deprecation_notice, DeprecationWarning, stacklevel=2)

    return _export(
        f,
        args,
        kwargs,
        constraints,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
    )
def _unlift_user_inputs_to_buffers( | |
gm_torch_level: torch.fx.GraphModule, | |
aot_export_args | |
) -> List[str]: | |
flat_args = pytree.tree_leaves(aot_export_args) | |
user_input_names = [] | |
with gm_torch_level.graph.inserting_before(): | |
for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)): | |
assert node.op == "placeholder" | |
user_input_names.append(node.name) | |
if isinstance(arg, torch.Tensor): | |
assert not hasattr(gm_torch_level, node.name) | |
gm_torch_level.register_buffer(node.name, arg) | |
get_attr = gm_torch_level.graph.get_attr(node.name) | |
node.replace_all_uses_with(get_attr) | |
get_attr.meta = copy.copy(node.meta) | |
for node in list(gm_torch_level.graph.nodes): | |
if node.op == "placeholder": | |
assert len(node.users) == 0 | |
gm_torch_level.graph.erase_node(node) | |
gm_torch_level.recompile() | |
return user_input_names | |
def _lift_buffers_to_user_inputs( | |
gm: torch.fx.GraphModule, | |
graph_signature: GraphSignature, | |
user_input_names: List[str] | |
) -> Dict[str, str]: | |
assert len(graph_signature.user_inputs) == 0 | |
assert graph_signature.backward_signature is None | |
names = set(user_input_names) | |
placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] | |
# user inputs are always added in the end | |
start = len(graph_signature.parameters) | |
end = start + len(graph_signature.buffers) | |
buffer_nodes = placeholders[start:end] | |
last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None | |
old_nodes: Dict[str, torch.fx.Node] = {} | |
for node in buffer_nodes: | |
buffer_name = graph_signature.inputs_to_buffers[node.name] | |
if buffer_name not in names: | |
continue | |
old_nodes[buffer_name] = node | |
replaces = {} | |
new_node_names: Dict[str, str] = {} | |
with gm.graph.inserting_after(last_placeholder_node): | |
for name in reversed(user_input_names): | |
new_node = gm.graph.placeholder(name) | |
new_node.target = new_node.name | |
new_node_names[name] = new_node.name | |
if name in old_nodes: | |
old_node = old_nodes[name] | |
new_node.meta = copy.copy(old_node.meta) | |
old_node.replace_all_uses_with(new_node) | |
replaces[old_node.name] = new_node.name | |
new_node_names = dict(reversed(new_node_names.items())) | |
for old_node in old_nodes.values(): | |
gm.graph.erase_node(old_node) | |
gm.recompile() | |
graph_signature.buffers = [b for b in graph_signature.buffers if b not in names] | |
graph_signature.inputs_to_buffers = { | |
i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names | |
} | |
user_inputs_to_mutate = { | |
o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names | |
} | |
graph_signature.buffers_to_mutate = { | |
o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names | |
} | |
graph_signature.user_inputs.extend(new_node_names.values()) # type: ignore[arg-type] | |
graph_signature.user_outputs = [ | |
replaces[o] if o in replaces else o for o in graph_signature.user_outputs | |
] | |
return user_inputs_to_mutate # type: ignore[return-value] | |
def _export_non_strict(
    mod,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    *,
    transform=lambda x: x  # TODO(zhxchen17) Revisit if this is needed later.
):
    """
    Run `aot_export_module` on `mod` with fake inputs and package the result.

    Args:
        mod: the module handed to (the transformed) `aot_export_module`.
        fake_args: fake positional inputs.
        fake_kwargs: fake keyword inputs; their values are appended, in order,
            to the positional inputs.
        fake_params_buffers: fake parameters/buffers the module is
            reparametrized with while exporting.
        transform: wrapper applied to `aot_export_module` before it is called.

    Returns:
        A small record with the fields `gm` (the exported GraphModule), `sig`
        (its `ExportGraphSignature`) and `tensor_constants` (the result of
        `lift_constant_tensor_pass`).
    """
    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with torch.nn.utils.stateless._reparametrize_module(mod, fake_params_buffers):
        gm, graph_signature = transform(aot_export_module)(
            mod,
            (*fake_args, *fake_kwargs.values()),
            trace_joint=False
        )

    # NOTE: aot_export adds symint metadata for placeholders with int values;
    # since these become specialized, we replace such metadata with the original values
    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
    index = 0
    total_param_buffers = len(graph_signature.parameters) + len(graph_signature.buffers)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # Placeholders past the parameters and buffers are user inputs.
            if index >= total_param_buffers:
                user_arg = flat_args[index - total_param_buffers]
                if not isinstance(user_arg, torch.Tensor):
                    node.meta["val"] = user_arg
            index += 1

    is_joint = graph_signature.backward_signature is not None

    def make_argument_spec(node) -> ArgumentSpec:
        # Classify a node by the kind of value recorded in its "val" metadata.
        assert "val" in node.meta, f"{node} has no 'val' metadata field"
        val = node.meta["val"]
        if isinstance(val, FakeTensor):
            return TensorArgument(name=node.name)
        elif isinstance(val, torch.SymInt):
            return SymIntArgument(name=node.name)
        else:
            return ConstantArgument(value=val)

    input_specs, output_specs = _sig_to_specs(
        user_inputs=set(graph_signature.user_inputs),
        inputs_to_parameters=graph_signature.inputs_to_parameters,  # type: ignore[arg-type]
        inputs_to_buffers=graph_signature.inputs_to_buffers,  # type: ignore[arg-type]
        user_outputs=set(graph_signature.user_outputs),  # type: ignore[arg-type]
        buffer_mutations=graph_signature.buffers_to_mutate,  # type: ignore[arg-type]
        user_input_mutations=gm.meta.get("user_inputs_to_mutate", {}),  # type: ignore[arg-type]
        grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {},  # type: ignore[arg-type, union-attr]
        grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {},  # type: ignore[arg-type, union-attr]
        loss_output=graph_signature.backward_signature.loss_output if is_joint else None,  # type: ignore[arg-type, union-attr]
        inputs=[make_argument_spec(node) for node in gm.graph.nodes if node.op == "placeholder"],
        # The last node of the graph is read for the outputs via its args.
        outputs=[make_argument_spec(node) for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)],
    )
    export_graph_signature = ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)

    tensor_constants = lift_constant_tensor_pass(gm, export_graph_signature)

    # Must be a dataclass: it is constructed positionally just below, which a
    # plain class with only annotations would reject with TypeError.
    @dataclasses.dataclass
    class _ExportedProgramNonStrict:
        gm: torch.fx.GraphModule
        sig: ExportGraphSignature
        tensor_constants: Dict[str, torch.Tensor]

    return _ExportedProgramNonStrict(
        gm,
        export_graph_signature,
        tensor_constants,
    )
def _get_params_buffers(mod: torch.nn.Module) -> Dict[str, torch.Tensor]: | |
params_buffers: Dict[str, torch.Tensor] = {} | |
for name, param in mod.named_parameters(remove_duplicate=False): | |
params_buffers[name] = param | |
for name, buffer in mod.named_buffers(remove_duplicate=False): | |
params_buffers[name] = buffer | |
return params_buffers | |
def _export( | |
f: Callable, | |
args: Tuple[Any, ...], | |
kwargs: Optional[Dict[str, Any]] = None, | |
constraints: Optional[List[Constraint]] = None, | |
*, | |
strict: bool = True, | |
preserve_module_call_signature: Tuple[str, ...] = (), | |
) -> ExportedProgram: | |
""" | |
Traces either an nn.Module's forward function or just a callable with PyTorch | |
operations inside and produce a ExportedProgram. | |
Args: | |
m: the `nn.Module` or callable to trace. | |
args: example positional inputs. | |
kwargs: optional example keyword inputs. | |
constraints: A optional list of constraints on the dynamic arguments specifying | |
their possible range of their shapes | |
preserve_module_call_signature: A list of submodule paths for which the original | |
calling conventions are preserved as metadata. | |
Returns: | |
An ExportedProgram containing the traced method. | |
""" | |
constraints = constraints or [] | |
kwargs = kwargs or {} | |
if not strict: | |
assert isinstance(f, torch.nn.Module) | |
assert len(preserve_module_call_signature) == 0 | |
assert len(constraints) == 0, "dynamic shape NYI" | |
assert len(kwargs) == 0, "keyword arguments NYI" | |
out_spec = None | |
def _tuplify_outputs(aot_export): | |
def _aot_export_non_strict(mod, args, **kwargs): | |
class Wrapper(torch.nn.Module): | |
def __init__(self, mod): | |
super().__init__() | |
self._export_root = mod | |
def forward(self, *args, **kwargs): | |
nonlocal out_spec | |
flat_outs, out_spec = pytree.tree_flatten(self._export_root(*args, **kwargs)) | |
return tuple(flat_outs) | |
gm, sig = aot_export(Wrapper(mod), args, **kwargs) | |
def strip_root(x): | |
return x[len('_export_root.'):] if x.startswith('_export_root.') else x | |
sig.parameters = pytree.tree_map(strip_root, sig.parameters) | |
sig.buffers = pytree.tree_map(strip_root, sig.buffers) | |
sig.inputs_to_buffers = pytree.tree_map(strip_root, sig.inputs_to_buffers) | |
sig.inputs_to_parameters = pytree.tree_map(strip_root, sig.inputs_to_parameters) | |
sig.buffers_to_mutate = pytree.tree_map(strip_root, sig.buffers_to_mutate) | |
return gm, sig | |
return _aot_export_non_strict | |
ep_non_strict = _export_non_strict(f, args, {}, f.state_dict(), transform=_tuplify_outputs) | |
assert out_spec is not None | |
return ExportedProgram( | |
ep_non_strict.gm, | |
ep_non_strict.gm.graph, | |
ep_non_strict.sig, | |
_get_params_buffers(f), | |
{}, | |
[], | |
[ModuleCallEntry("", ModuleCallSignature([], [], pytree.tree_flatten((args, {}))[1], out_spec))], | |
(args, kwargs), | |
tensor_constants=ep_non_strict.tensor_constants, | |
) | |
gm_torch_level = _export_to_torch_ir( | |
f, | |
args, | |
kwargs, | |
constraints, | |
preserve_module_call_signature=preserve_module_call_signature, | |
) | |
params_buffers = _get_params_buffers(gm_torch_level) | |
# We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo. | |
fake_args, fake_kwargs, fake_params_buffers, dynamo_fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs) | |
# First, we want to pass through the graph to try populating | |
# val field for getattr if there is anything missing. | |
# THis can happen when quantization adds extra params and forgets | |
# to update "val" | |
for node in gm_torch_level.graph.nodes: | |
if node.op == "get_attr" and "val" not in node.meta: | |
attr = getattr(gm_torch_level, node.target) | |
# Checks if it is not a HigherOrderOp branch or a module | |
if not isinstance(attr, torch.nn.Module): | |
assert dynamo_fake_mode is not None, ( | |
"Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." | |
) | |
node.meta["val"] = dynamo_fake_mode.from_tensor(attr, static_shapes=True) | |
# When aot_export lifts the params, we lose the nn_module_stack | |
# and source_fn from the param nodes as they are treated as fresh inputs | |
# Therefore, we manually extract them before calling into aot_export | |
params_buffers_to_node_meta = {} | |
for node in gm_torch_level.graph.nodes: | |
target = node.target | |
meta = node.meta | |
if node.op == "call_module": | |
submodule = getattr(gm_torch_level, target) | |
if isinstance(submodule, torch.nn.Module): | |
for name, _ in submodule.named_parameters(recurse=True, remove_duplicate=False): | |
params_buffers_to_node_meta[target + "." + name] = meta | |
for name, _ in submodule.named_buffers(recurse=True, remove_duplicate=False): | |
params_buffers_to_node_meta[target + "." + name] = meta | |
if node.op == "get_attr": | |
submodule = getattr(gm_torch_level, target) | |
if not isinstance(submodule, torch.fx.GraphModule): | |
params_buffers_to_node_meta[target] = meta | |
# If the call_function uses param as input, we also need to update params' meta | |
# with this call_function node's meta. | |
# This is basically the same flow as torch.fx.traceback.preserve_meta() | |
if node.op == "call_function" and not isinstance(node.target, torch._ops.HigherOrderOperator): | |
for arg in node._input_nodes: | |
if arg.op == "get_attr": | |
for entry in torch.fx.proxy._COPY_META_FIELDS: | |
if entry in meta: | |
params_buffers_to_node_meta[arg.target][entry] = meta[entry] | |
# Fix the graph output signature to be tuple if scalar | |
out_spec = orig_out_spec = gm_torch_level._out_spec | |
assert out_spec is not None | |
# aot_export expect the return type to always be a tuple. | |
if out_spec.type not in (list, tuple): | |
out_spec = pytree.TreeSpec(tuple, None, [out_spec]) | |
orig_args = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] | |
gm_torch_level.graph._codegen = _PyTreeCodeGen( | |
_PyTreeInfo( | |
orig_args, | |
gm_torch_level._in_spec, | |
out_spec, | |
) | |
) | |
gm_torch_level.recompile() | |
param_buffer_table: Dict[str, str] = {} | |
if isinstance(f, torch.nn.Module): | |
param_lookup: Dict[int, List[str]] = {} | |
buffer_lookup: Dict[int, List[str]] = {} | |
for name, param in f.named_parameters(remove_duplicate=False): | |
param_lookup.setdefault(id(param), []).append(name) | |
for name, buffer in f.named_buffers(remove_duplicate=False): | |
buffer_lookup.setdefault(id(buffer), []).append(name) | |
for dynamo_name, dynamo_param in gm_torch_level.named_parameters(remove_duplicate=False): | |
assert dynamo_name not in param_buffer_table | |
if id(dynamo_param) in param_lookup: | |
param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop() | |
for dynamo_name, dynamo_buffer in gm_torch_level.named_buffers(remove_duplicate=False): | |
assert dynamo_name not in param_buffer_table | |
if id(dynamo_buffer) in buffer_lookup: | |
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop() | |
if isinstance(f, torch.nn.Module): | |
_normalize_nn_module_stack(gm_torch_level, type(f)) | |
def _process_user_inputs(aot_export): | |
def _aot_export_strict(gm_torch_level: torch.fx.GraphModule, args, **kwargs): | |
user_input_names = _unlift_user_inputs_to_buffers(gm_torch_level, args) | |
gm, graph_signature = aot_export(gm_torch_level, (), **kwargs) | |
user_inputs_to_mutate = _lift_buffers_to_user_inputs(gm, graph_signature, user_input_names) | |
# TODO unfortunately preserving graph-level metadata is not | |
# working well with aot_export. So we manually copy it. | |
# (The node-level meta is addressed above.) | |
gm.meta.update(gm_torch_level.meta) | |
assert "user_inputs_to_mutate" not in gm.meta | |
gm.meta["user_inputs_to_mutate"] = user_inputs_to_mutate | |
return gm, graph_signature | |
return _aot_export_strict | |
# Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict | |
# to follow the order in orig_args and correctly call module | |
ep_non_strict = _export_non_strict( | |
gm_torch_level, | |
fake_args, | |
_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs), | |
fake_params_buffers, | |
transform=_process_user_inputs | |
) | |
gm = ep_non_strict.gm | |
export_graph_signature = ep_non_strict.sig | |
tensor_constants = ep_non_strict.tensor_constants | |
# After aot_export, set the param/buffer metadata back into placeholders | |
# Technically, users can still construct this data from param names | |
# without relying on this metadata | |
for node in gm.graph.nodes: | |
if node.op == "placeholder": | |
if node.target in export_graph_signature.inputs_to_parameters: | |
param_name = export_graph_signature.inputs_to_parameters[node.target] | |
if param_name in params_buffers_to_node_meta: | |
for k, v in params_buffers_to_node_meta[param_name].items(): | |
node.meta[k] = v | |
if node.target in export_graph_signature.inputs_to_buffers: | |
buffer_name = export_graph_signature.inputs_to_buffers[node.target] | |
if buffer_name in params_buffers_to_node_meta: | |
for k, v in params_buffers_to_node_meta[buffer_name].items(): | |
node.meta[k] = v | |
# The unbacked symint symbols are updated in aot_export | |
# so we serialize them here instead of inside dynamo | |
# dynamo_fake_mode can be None if there's no placeholder in gm_torch_level | |
if dynamo_fake_mode: | |
gm.meta["inline_constraints"] = { | |
k: v | |
for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items() | |
if re.match(r"^[if]\d+$", str(k)) | |
} | |
num_lifted = next( | |
(i for i, s in enumerate(export_graph_signature.input_specs) if s.kind == InputKind.USER_INPUT), 0 | |
) | |
flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) | |
range_constraints, equality_constraints = _process_constraints( | |
gm, | |
num_lifted, | |
flat_args, | |
) | |
if isinstance(f, torch.nn.Module): | |
_replace_param_buffer_names(param_buffer_table, export_graph_signature) | |
params_buffers = {param_buffer_table.get(name, name): tensor for name, tensor in params_buffers.items()} | |
module_call_signatures = { | |
fqn: ModuleCallSignature(inputs=[], outputs=[], **specs) | |
for fqn, specs in gm_torch_level.meta["module_call_specs"].items() | |
} | |
if len(preserve_module_call_signature) > 0: | |
res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) | |
assert res is not None | |
gm = res.graph_module | |
assert orig_out_spec is not None | |
exported_program = ExportedProgram( | |
gm, | |
gm.graph, | |
export_graph_signature, | |
# TODO(zhxchen17) Return empty state_dict for functions. | |
params_buffers, | |
range_constraints, | |
equality_constraints, | |
[ModuleCallEntry("", ModuleCallSignature(inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec))] + | |
[ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()], | |
(args, kwargs), | |
tensor_constants=tensor_constants, | |
) | |
if len(range_constraints) > 0 or len(equality_constraints) > 0: | |
exported_program = exported_program._transform( | |
_AddRuntimeAssertionsForInlineConstraintsPass(range_constraints, equality_constraints) | |
) | |
return exported_program | |
def _reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any], kwargs: Dict[str, Any]): | |
assert len(arg_names) == len(args) + len(kwargs), ( | |
f"Total number of arg names is expected to be {len(arg_names)} " | |
f"but got {len(args)} positional args, {len(kwargs)} kwargs." | |
) | |
return {kw_name: kwargs[kw_name] for kw_name in arg_names[len(args):]} | |
def save(
    ep: ExportedProgram,
    f: Union[str, pathlib.Path, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """
    Serialize ``ep`` and write it to ``f`` as a zip archive.

    The archive contains one ``serialized_<field>.json`` entry per field of
    the ``SerializedArtifact`` produced by ``serialize``, a ``version`` entry
    holding ``str(SCHEMA_VERSION)``, and one ``extra_files/<name>`` entry per
    item of ``extra_files``.

    Args:
        ep: the exported program to serialize.
        f: destination; a ``str`` / ``pathlib.Path`` is converted to ``str``,
            anything else is handed to ``zipfile.ZipFile`` unchanged.
        extra_files: optional mapping from entry name to content. Each content
            value is encoded with ``.encode('utf-8')`` before being written.
        opset_version: forwarded to ``serialize`` as its second argument.
    """
    from .serde.serialize import serialize, SerializedArtifact
    from .serde.schema import SCHEMA_VERSION

    artifact: SerializedArtifact = serialize(ep, opset_version)

    # Path-like destinations are normalised to plain strings; other objects
    # (e.g. BytesIO) are passed through as-is.
    destination = str(f) if isinstance(f, (str, pathlib.Path)) else f

    with zipfile.ZipFile(destination, 'w') as archive:
        # One archive entry per field of the serialized artifact.
        for artifact_field in dataclasses.fields(artifact):
            field_name = artifact_field.name
            archive.writestr(
                f"serialized_{field_name}.json", getattr(artifact, field_name)
            )

        archive.writestr('version', str(SCHEMA_VERSION))

        # A missing or empty ``extra_files`` contributes no entries.
        for extra_file_name, content in (extra_files or {}).items():
            archive.writestr(
                f"extra_files/{extra_file_name}", content.encode('utf-8')
            )
def load(
    f: Union[str, pathlib.Path, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """
    Read a zip archive from ``f`` and deserialize the ExportedProgram in it.

    Args:
        f: source archive; a ``str`` / ``pathlib.Path`` is converted to
            ``str``, anything else is handed to ``zipfile.ZipFile`` unchanged.
        extra_files: optional dict whose keys name entries stored under
            ``extra_files/`` in the archive. The dict is updated in place:
            each value is replaced by the entry's UTF-8 decoded content.
        expected_opset_version: forwarded to ``deserialize`` as its second
            argument.

    Returns:
        The value returned by ``deserialize`` for the stored artifact.

    Raises:
        RuntimeError: if the archive's ``version`` entry differs from the
            current ``SCHEMA_VERSION``.
        KeyError: raised by ``zipfile`` when an expected entry (including one
            requested through ``extra_files``) is not in the archive.
    """
    if isinstance(f, (str, pathlib.Path)):
        f = str(f)

    with zipfile.ZipFile(f, 'r') as zipf:
        # Refuse archives written with a different schema version.
        version = int(zipf.read('version'))
        from .serde.schema import SCHEMA_VERSION

        if version != SCHEMA_VERSION:
            raise RuntimeError(
                f"Serialized version {version} does not match our current "
                f"schema version {SCHEMA_VERSION}."
            )

        from .serde.serialize import deserialize, SerializedArtifact

        # Rebuild the artifact from one ``serialized_<field>.json`` entry per
        # dataclass field.
        artifact: SerializedArtifact = SerializedArtifact(
            **{
                field.name: zipf.read(f"serialized_{field.name}.json")
                for field in dataclasses.fields(SerializedArtifact)
            }
        )

        # Forward the caller's opset expectation; previously this argument was
        # accepted by ``load`` but silently dropped.
        ep = deserialize(artifact, expected_opset_version)

        # Fill in the caller-provided extra_files map in place. Only values of
        # existing keys are reassigned, so the dict size is unchanged while
        # iterating.
        if extra_files is not None:
            for filename in extra_files:
                extra_files[filename] = zipf.read(f"extra_files/{filename}").decode('utf-8')

        return ep
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    constraints: Optional[List[Constraint]] = None,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
) -> str:
    """
    Note: this function is not stable yet

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable cpp code from the program, and returns
    the path to the generated shared library

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        constraints: An optional list of constraints on the dynamic arguments
            specifying the possible range of their shapes. Deprecated: passing
            it emits a warning; use ``dynamic_shapes`` instead.

        dynamic_shapes: An experimental new feature designed to subsume ``constraints``.
            A dict mapping argument names of ``f`` to their dynamic shape
            specifications, as follows. Dynamic shape specifications can be a
            dict from dynamic dimensions to ``Dim`` types, or a tuple/list of
            ``Optional[Dim]`` corresponding to each input dimension.
            Only consulted when ``constraints`` is None.

        options: A dictionary of options to control inductor

        remove_runtime_assertions: accepted for interface compatibility; it is
            not read anywhere in this function.

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

    Returns:
        Path to the generated shared library
    """
    legacy_constraints_given = constraints is not None
    if legacy_constraints_given:
        warnings.warn(
            "The constraints field is deprecated. "
            "Please use dynamic_shapes instead."
        )

    # The imported name is not used below; the import is kept so the
    # function's import-time behavior stays exactly as before.
    from torch._inductor.decomposition import select_decomp_table

    if not legacy_constraints_given:
        # No explicit constraints: derive them from the dynamic_shapes spec.
        constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)

    # We want to export to Torch IR here to utilize the pre_grad passes in
    # inductor, which run on Torch IR.
    torch_ir_gm = _export_to_torch_ir(
        f,
        args,
        kwargs,
        constraints,
        disable_constraint_solver=disable_constraint_solver,
    )

    # Flatten positional and keyword example inputs into a single leaf list.
    example_kwargs = kwargs or {}
    flat_example_inputs = pytree.arg_tree_leaves(*args, **example_kwargs)

    with torch.no_grad():
        return torch._inductor.aot_compile(torch_ir_gm, flat_example_inputs, options)  # type: ignore[arg-type]