import torch
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING

from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload

if TYPE_CHECKING:
    from .node import Argument


@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.
    """
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}


def _nonzero_schemas():
    signatures = []

    def nonzero(self):
        pass
    signatures.append(inspect.signature(nonzero))

    # intentionally redefined: capture the signature of the second overload
    def nonzero(self, *, as_tuple : bool):
        pass
    signatures.append(inspect.signature(nonzero))

    return signatures


_manual_overrides[torch.nonzero] = _nonzero_schemas()
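

# Illustrative sketch (never called; not part of the module's logic): the manual
# override above means signature lookup for `torch.nonzero` is answered from
# `_manual_overrides` rather than from the TorchScript operator registry.
def _example_manual_override():
    sigs = _manual_overrides[torch.nonzero]
    assert len(sigs) == 2  # nonzero(self) and nonzero(self, *, as_tuple)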


class _FakeGlobalNamespace:
    def __getattr__(self, name):
        if name == 'torch':
            return torch
        raise RuntimeError('Expected a torch namespace lookup')


_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      't': typing.TypeVar('t')}
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)


def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    Convert a TorchScript type to a Python type (including subtypes) by
    eval'ing the annotation_str. _type_eval_globals sets up expressions
    like "List" and "Future" to map to actual types (typing.List and jit.Future).
    """
    return eval(ts_type.annotation_str, _type_eval_globals)
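

# Minimal sanity-check sketch (illustrative only, never called; assumes
# `torch._C.parse_schema` is available, as on recent PyTorch builds):
def _example_type_conversion():
    schema = torch._C.parse_schema('aten::relu(Tensor self) -> Tensor')
    arg_type = schema.arguments[0].type
    # annotation_str here is "Tensor", which _type_eval_globals maps to torch.Tensor
    assert _torchscript_type_to_python_type(arg_type) is torch.Tensor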


def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    parameters : List[inspect.Parameter] = []
    for arg in ts_schema.arguments:
        arg_type = _torchscript_type_to_python_type(arg.type)
        default = arg.default_value if arg.has_default_value() else inspect.Parameter.empty
        # TorchScript schemas name the receiver 'self'; rename it to 'input' to
        # match the name used at Python call sites for torch operators
        name = arg.name if arg.name != 'self' else 'input'
        kind = inspect.Parameter.KEYWORD_ONLY if arg.kwarg_only else inspect.Parameter.POSITIONAL_OR_KEYWORD
        parameters.append(inspect.Parameter(name=name, kind=kind, default=default, annotation=arg_type))
    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
    if len(return_types) == 0:
        return_type = None
    elif len(return_types) == 1:
        return_type = return_types[0]
    else:
        return_type = tuple(return_types)

    return inspect.Signature(parameters, return_annotation=return_type)
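

# Sketch of the conversion above on a hand-parsed schema (illustrative only,
# never called; assumes `torch._C.parse_schema` is available):
def _example_schema_to_signature():
    schema = torch._C.parse_schema(
        'aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor')
    sig = _torchscript_schema_to_signature(schema)
    # 'self' is renamed to 'input'; 'alpha' is keyword-only with a default of 1
    assert list(sig.parameters) == ['input', 'other', 'alpha']
    assert sig.parameters['alpha'].kind is inspect.Parameter.KEYWORD_ONLY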


@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

    if signatures and schemas:
        matched_schemas = []

        # Iterate through all of the schemas until we find one that matches the
        # call site. Record every schema whose signature binds successfully
        for candidate_signature, schema in zip(signatures, schemas):
            try:
                candidate_signature.bind(*args, **kwargs)
                matched_schemas.append((candidate_signature, schema))
            except TypeError:
                continue

        def throw_if_mutable(schema):
            if schema.is_mutable:
                raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
                                   f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                                   f'are not supported')

        if len(matched_schemas) == 0:
            # Did not match any schema. Cannot check for mutation
            pass
        elif len(matched_schemas) == 1:
            # Matched exactly one schema, unambiguous
            _, schema_to_check = matched_schemas[0]
            throw_if_mutable(schema_to_check)
        else:
            # Ambiguous schema match. Since mutability checking is best effort,
            # do nothing
            pass
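

# Usage sketch (illustrative, never called): the in-place overload
# `torch.ops.aten.add_.Tensor` carries a mutable schema, so checking it raises.
def _example_mutable_check():
    t = torch.randn(3)
    try:
        check_for_mutable_operation(torch.ops.aten.add_.Tensor, (t, t), {})
    except RuntimeError as err:
        print(err)  # Tried to trace mutable operation aten::add_...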


@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional list of Python
            signatures and the optional list of TorchScript schemas
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
    else:
        override = _manual_overrides.get(op)
        if override:
            return (override, None) if return_schemas else override

        aten_fn = torch.jit._builtins._find_builtin(op)

        if aten_fn is None:
            return (None, None) if return_schemas else None
        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
    return (signatures, schemas) if return_schemas else signatures
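

# Quick sketch (illustrative, never called): `torch.add` resolves to an aten
# builtin with several overloads (Tensor/Tensor, Tensor/Scalar, out=...), so
# more than one signature comes back.
def _example_overload_signatures():
    sigs = get_signature_for_torch_op(torch.add)
    assert sigs is not None and len(sigs) > 1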


@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Produce a `List[...]` or `Tuple[..., ...]` type hint that unifies the
    element types of the container `x` of types. Returns `x` unchanged on failure.
    """
    try:
        if isinstance(x, (list, tuple)):
            if isinstance(x, list):
                def ret_type(x):
                    return List[x]
            else:
                def ret_type(x):
                    return Tuple[x, ...]
            if len(x) == 0:
                return ret_type(Any)
            base_type = x[0]
            for t in x:
                if issubclass(t, base_type):
                    continue
                elif issubclass(base_type, t):
                    base_type = t
                else:
                    return ret_type(Any)
            return ret_type(base_type)
    except Exception:
        # We tried to unify the element types but failed, e.g. because an
        # element was not a class. Fall back to returning `x` unchanged
        warnings.warn(f"Unable to create type hint from the type {x}")
    return x
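

# Behavior sketch (illustrative, never called): `create_type_hint` expects a
# container of *types*, not values.
def _example_create_type_hint():
    assert create_type_hint([int, int]) == List[int]
    # bool is a subclass of int, so the unified element type is int
    assert create_type_hint((bool, int)) == Tuple[int, ...]
    # Unrelated element types fall back to Any
    assert create_type_hint([int, str]) == List[Any]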


@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)

    if signature_type is argument_type:
        return True

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union
    if sig_origin_type is typing.Union and signature_type != argument_type:
        sig_contained = signature_type.__args__
        return any(type_matches(c, argument_type) for c in sig_contained)

    if signature_type is List[int] and argument_type is int:
        # int can be promoted to List[int]
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
        sig_el_type = signature_type.__args__[0]
        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
            return False
        if getattr(argument_type, '__origin__', None) in {list, List}:
            return issubclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, '__origin__', None) not in {tuple, Tuple}:
                return False
            contained = t.__args__
            if t.__args__ == ((),):  # empty tuple type, i.e. Tuple[()]
                return True
            return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True
    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return issubclass(argument_type, signature_type)

    return False
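

# Behavior sketch (illustrative, never called):
def _example_type_matches():
    assert type_matches(List[int], int)                        # promotion rule above
    assert type_matches(Optional[torch.Tensor], torch.Tensor)  # Union member match
    assert type_matches(numbers.Number, float)
    assert not type_matches(torch.Tensor, int)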


@compatibility(is_backward_compatible=False)
def normalize_function(
        target: Callable, args: Tuple[Any, ...], kwargs : Optional[Dict[str, Any]] = None,
        arg_types : Optional[Tuple[Any, ...]] = None, kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the function's signature, returning
    exclusively kwargs in positional order if `normalize_to_only_use_kwargs`
    is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any, ...]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any, ...]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    if kwargs is None:
        kwargs = {}
    new_args_and_kwargs = None
    if not isinstance(target, types.BuiltinFunctionType) and not (
        isinstance(target, (OpOverloadPacket, OpOverload))
    ):
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched['if_true'], dispatched['if_false']
            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
    else:
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
        matched_schemas = []
        if torch_op_schemas:
            # Iterate through all of the schemas until we find one that matches the
            # call site. Record every candidate signature that binds successfully
            for candidate_signature in torch_op_schemas:
                try:
                    candidate_signature.bind(*args, **kwargs)
                    matched_schemas.append(candidate_signature)
                except TypeError:
                    continue

            if len(matched_schemas) == 0:
                # Did not match any schema. Cannot normalize
                pass
            elif len(matched_schemas) == 1:
                # Matched exactly one schema, unambiguous
                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                                             normalize_to_only_use_kwargs)
            else:
                if arg_types is not None or kwarg_types is not None:
                    arg_types = arg_types if arg_types else cast(Tuple[Any, ...], ())
                    kwarg_types = kwarg_types if kwarg_types else {}
                    for candidate_signature in torch_op_schemas:
                        sig_matches = True
                        try:
                            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
                            for arg_name, arg_type in bound_types.arguments.items():
                                param = candidate_signature.parameters[arg_name]
                                sig_matches = sig_matches and type_matches(param.annotation, arg_type)
                        except TypeError:
                            sig_matches = False
                        if sig_matches:
                            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                                                         normalize_to_only_use_kwargs)
                            break
                else:
                    # Matched more than one schema. In this situation, the caller must provide
                    # the types of the arguments of the overload they expect
                    schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
                    raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                                       f'the schema match was ambiguous! Please provide argument types to '
                                       f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')

    return new_args_and_kwargs
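

# Usage sketch (illustrative, never called): `torch.nn.functional.relu` is a
# plain Python function, so its signature normalizes directly; `torch.add` is a
# builtin with several overloads, so argument types disambiguate the match.
def _example_normalize_function():
    t = torch.randn(3)
    pair = normalize_function(torch.nn.functional.relu, (t,),
                              normalize_to_only_use_kwargs=True)
    # e.g. ArgsKwargsPair(args=(), kwargs={'input': t, 'inplace': False})
    pair = normalize_function(torch.add, (t, t),
                              arg_types=(torch.Tensor, torch.Tensor))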


@compatibility(is_backward_compatible=False)
def normalize_module(
        root: torch.nn.Module, target: str, args: Tuple[Any, ...], kwargs : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the module's forward signature,
    returning exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): qualified name of the submodule whose call we are normalizing
        args (Tuple[Any, ...]): Tuple of args to the module
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the module
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           "have that target!") from e
    if hasattr(submod.__class__, '__name__'):
        classname = submod.__class__.__name__
        if getattr(torch.nn, classname, None) == submod.__class__:
            sig = inspect.signature(inspect.unwrap(submod.forward))
            if kwargs is None:
                kwargs = {}
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            return new_args_and_kwargs
    return None
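

# Usage sketch (illustrative, never called): `nn.Sequential` registers its
# children under string indices, so '0' names the Conv2d submodule below.
def _example_normalize_module():
    root = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    pair = normalize_module(root, '0', (torch.randn(1, 3, 8, 8),),
                            normalize_to_only_use_kwargs=True)
    # Conv2d.forward takes a single argument, so everything lands in kwargs:
    # e.g. ArgsKwargsPair(args=(), kwargs={'input': <tensor>})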


def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
    """
    Given a call target's signature, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
    this normalization.

    Args:
        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported.
    """

    # Don't currently support positional-only or varargs (*args, **kwargs)
    # signatures
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
        return None

    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    new_kwargs : Dict[str, Any] = {}
    new_args : List[Any] = []
    for i, param in enumerate(sig.parameters):
        if not normalize_to_only_use_kwargs and i < len(args):
            new_args.append(bound_args.arguments[param])
        else:
            new_kwargs[param] = bound_args.arguments[param]

    return ArgsKwargsPair(tuple(new_args), new_kwargs)
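

# Behavior sketch (illustrative, never called): positional arguments up to the
# number supplied at the call site stay positional; everything else, including
# populated defaults, becomes a kwarg.
def _example_normalize_helper():
    def f(a, b=2, *, c=3):
        pass
    pair = _args_kwargs_to_normalized_args_kwargs(
        inspect.signature(f), (1,), {'c': 10}, normalize_to_only_use_kwargs=False)
    assert pair == ArgsKwargsPair(args=(1,), kwargs={'b': 2, 'c': 10})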