|
import inspect |
|
from typing import Dict, List, Union |
|
|
|
from torch import _C |
|
from torch.onnx import _constants |
|
from torch.onnx._internal import registration |
|
|
|
|
|
class _TorchSchema: |
|
def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None: |
|
if isinstance(schema, _C.FunctionSchema): |
|
self.name: str = schema.name |
|
self.overload_name: str = schema.overload_name |
|
self.arguments: List[str] = [arg.name for arg in schema.arguments] |
|
self.optional_arguments: List[str] = [] |
|
self.returns: List[str] = [ret.name for ret in schema.returns] |
|
self.opsets: List[int] = [] |
|
else: |
|
self.name = schema |
|
self.overload_name = "" |
|
self.arguments = [] |
|
self.optional_arguments = [] |
|
self.returns = [] |
|
self.opsets = [] |
|
|
|
def __str__(self) -> str: |
|
s = ( |
|
f"{self.name}.{self.overload_name}(" |
|
+ ", ".join(self.arguments) |
|
+ ") -> (" |
|
+ ", ".join(self.returns) |
|
+ ")" |
|
+ " in opsets " |
|
+ ", ".join(str(opset) for opset in self.opsets) |
|
) |
|
return s |
|
|
|
def __hash__(self): |
|
|
|
return hash(self.name) |
|
|
|
def __eq__(self, other) -> bool: |
|
if not isinstance(other, _TorchSchema): |
|
return False |
|
|
|
return self.name == other.name |
|
|
|
def is_aten(self) -> bool: |
|
return self.name.startswith("aten::") |
|
|
|
def is_backward(self) -> bool: |
|
return "backward" in self.name |
|
|
|
|
|
def _symbolic_argument_count(func): |
|
params = [] |
|
signature = inspect.signature(func) |
|
optional_params = [] |
|
for name, parameter in signature.parameters.items(): |
|
if name in {"_outputs", "g"}: |
|
continue |
|
if parameter.default is parameter.empty: |
|
optional_params.append(parameter) |
|
else: |
|
params.append(str(parameter)) |
|
return params |
|
|
|
|
|
def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Collect every TorchScript schema that is not a backward op.

    Returns:
        Mapping from schema name to its ``_TorchSchema``. Schemas whose name
        contains ``backward`` are left out. When several schemas share a
        name, the one produced last by ``_C._jit_get_all_schemas()`` wins.
    """
    forward_schemas = {}
    for raw_schema in _C._jit_get_all_schemas():
        wrapped = _TorchSchema(raw_schema)
        if wrapped.is_backward():
            continue
        forward_schemas[wrapped.name] = wrapped
    return forward_schemas
|
|
|
|
|
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Build a schema record for every function name in the ONNX registry.

    For each registered name, the function group is queried at
    ``_constants.ONNX_MAX_OPSET``. When that yields a function, the recorded
    opsets run from the group's minimum supported opset up to and including
    the maximum. Otherwise the group is queried at opset 7 and the recorded
    opsets run from 7 up to (but excluding) ``_constants.ONNX_BASE_OPSET``.

    Returns:
        Mapping from registered function name to its ``_TorchSchema``.
    """
    max_opset = _constants.ONNX_MAX_OPSET
    schemas_by_name = {}

    for op_name in registration.registry.all_functions():
        group = registration.registry.get_function_group(op_name)
        assert group is not None

        schema = _TorchSchema(op_name)
        latest = group.get(max_opset)
        found_at_max = latest is not None

        # Fall back to the opset-7 lookup when nothing resolves at the max opset.
        symbolic_fn = latest if found_at_max else group.get(7)
        schema.arguments = _symbolic_argument_count(symbolic_fn)

        if found_at_max:
            supported = range(group.get_min_supported(), max_opset + 1)
        else:
            supported = range(7, _constants.ONNX_BASE_OPSET)
        schema.opsets = list(supported)

        schemas_by_name[op_name] = schema

    return schemas_by_name
|
|