import abc
import typing as t

import torch
import torch.fx
from torch.fx._compatibility import compatibility
from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS


__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports']


TargetTypeName = str

# Supported dtypes for a node's inputs in args and kwargs; `None` means any
# dtype is accepted. See `OperatorSupport` for the expected format.
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Sequence[torch.dtype]],
        t.Dict[str, t.Sequence[torch.dtype]],
    ]
]

SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]


@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
    """Interface for determining if an fx.Node is supported by a backend"""

    @abc.abstractmethod
    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        raise NotImplementedError()
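

# Illustrative sketch (not part of this module's API; the class name is
# hypothetical): a minimal `OperatorSupportBase` subclass that only accepts
# `call_function` nodes targeting `torch.add`.
#
#     class AddOnlySupport(OperatorSupportBase):
#         def is_node_supported(
#             self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
#         ) -> bool:
#             return node.op == "call_function" and node.target is torch.add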


@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
    """
    `_support_dict` maps a node.target typename to the supported input dtypes.

    The node.target typename is retrieved with the helper function
    `get_node_target()`.

    If the supported input dtypes entry is None, any dtype is supported.
    Otherwise the entry is a tuple of the form
    (([dtypes], ...), {"name": [dtypes], ...}).

    The first element, ([dtypes], ...), lists the dtypes supported for the
    inputs in node.args; the second element, {"name": [dtypes], ...}, lists the
    dtypes supported for the inputs in node.kwargs.

    To skip checking an input in args, put None in its position; e.g.
    (None, [torch.float]) means the dtype of the first input in args is not
    checked. Inputs in kwargs that are not listed are not checked.
    """

    _support_dict: SupportDict

    def __init__(
        self,
        support_dict: t.Optional[SupportDict] = None
    ):
        self._support_dict = support_dict or {}

    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """
        Args:
            `submodules`: mapping from module name to the module. This can be
            retrieved by calling model.named_modules().

            `node`: an fx node that we want to determine whether it's supported.

        Returns:
            `is_supported`: whether the arg `node` is supported.
        """
        if node.op not in CALLABLE_NODE_OPS:
            return True

        target = get_node_target(submodules, node)

        # Target not found in _support_dict means we don't support this op at all
        if target not in self._support_dict:
            return False

        # The rule for this target is None, meaning any dtype is accepted
        if self._support_dict[target] is None:
            return True

        args_dtypes, kwargs_dtypes = self._support_dict[target]

        # Check args dtypes
        for i, dtypes in enumerate(args_dtypes):
            if len(node.args) <= i:
                break

            # None indicates we don't care about the dtype of args[i]
            if dtypes is None:
                continue

            # If arg is not a node then we don't check it
            if not isinstance(node.args[i], torch.fx.Node):
                continue

            arg_dtype = _get_arg_dtype(node.args[i])
            if arg_dtype not in dtypes:
                return False

        # Check kwargs dtypes
        for k, dtypes in kwargs_dtypes.items():
            if k not in node.kwargs:
                continue

            # If kwarg is not a node then we don't check it
            if not isinstance(node.kwargs[k], torch.fx.Node):
                continue

            kwarg_dtype = _get_arg_dtype(node.kwargs[k])
            if kwarg_dtype not in dtypes:
                return False

        return True


IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]


@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
    """Wraps an `IsNodeSupported` function into an `OperatorSupportBase` instance.

    `IsNodeSupported` has the same call signature as
    `OperatorSupportBase.is_node_supported`.
    """
    class FunctionalOperatorSupport(OperatorSupportBase):
        def is_node_supported(
            self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
        ) -> bool:
            return is_node_supported(submodules, node)
    return FunctionalOperatorSupport()
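
# Illustrative usage sketch (the function and variable names below are
# hypothetical): build an `OperatorSupportBase` from a plain function without
# writing a subclass.
#
#     def _no_call_module(
#         submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
#     ) -> bool:
#         return node.op != "call_module"
#
#     no_call_module_support = create_op_support(_no_call_module)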


@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combines a sequence of `OperatorSupportBase` instances into a single
    `OperatorSupportBase` instance that evaluates each input instance and
    returns False if any of them reports False.
    """
    def _chain(submods, node) -> bool:
        return all(
            x.is_node_supported(submods, node)
            for x in op_support
        )
    return create_op_support(_chain)
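
# Illustrative usage sketch (the target name below is hypothetical): `chain`
# accepts a node only when every constituent rule accepts it.
#
#     combined = chain(
#         OperatorSupport({"torch.nn.functional.relu": None}),
#         OpSupports.decline_if_input_dtype(torch.float64),
#     )
#     # For a callable node, combined.is_node_supported(submodules, node) is
#     # True only if the node's target is listed *and* none of its input nodes
#     # has dtype torch.float64.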


@compatibility(is_backward_compatible=False)
class OpSupports:
    """A set of atomic `OperatorSupportBase` instances that can be combined together
    to form more complex operator support logic.
    """
    @classmethod
    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
        """Reports a node as non-supported if any of its input nodes has the given dtype."""

        def _decline_if_input_dtype(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            for arg in node.all_input_nodes:
                # Skip the dtype check for get_attr nodes
                if arg.op == "get_attr":
                    continue
                arg_dtype = _get_arg_dtype(arg)
                if arg_dtype == dtype:
                    return False
            return True
        return create_op_support(_decline_if_input_dtype)

    @classmethod
    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
        """
        If a node has a name that is in the disallow set, report it as non-supported.
        """
        def _decline_if_node_in_names(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            if node.name in disallow_set:
                return False
            else:
                return True
        return create_op_support(_decline_if_node_in_names)
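
    # Illustrative usage sketch (the node names below are hypothetical):
    #
    #     support = OpSupports.decline_if_node_in_names({"add_1", "relu_2"})
    #     # support.is_node_supported(submodules, node) returns False exactly
    #     # when node.name is "add_1" or "relu_2".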


def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
    assert isinstance(arg, torch.fx.Node)
    tensor_meta = arg.meta.get("tensor_meta")
    # Prefer the dtype recorded by shape propagation; otherwise fall back to the
    # Python type stored in meta["type"] for non-tensor values.
    dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
    return dtype