import torch |
from torch import Tensor |
aten = torch.ops.aten |
from typing import Optional, List, Dict, Set |
import inspect |
from torch.fx.operator_schemas import get_signature_for_torch_op |
import warnings |
# Default registry filled by ``register_decomposition``: maps the string form of
# an operator schema to the TorchScript-compiled decomposition for it.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
# ``__name__`` of every function registered so far; ``register_decomposition``
# consults it to reject a second decomposition with the same Python name.
function_name_set: Set[str] = set()
def check_decomposition_has_type_annotations(f):
    """Assert that ``f`` annotates every parameter and its return value.

    Args:
        f: The Python function to inspect (read through ``inspect.signature``).

    Raises:
        AssertionError: If any parameter of ``f`` has no annotation, or if
            ``f`` has no return annotation. The message names the function
            and, for parameters, the offending parameter.
    """
    # Public spelling of the sentinel ``inspect.signature`` uses for
    # "nothing was specified" (same object as ``inspect._empty``).
    missing = inspect.Parameter.empty
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        # Functions expose their name as ``__name__``; reading ``f.name`` here
        # would raise AttributeError and hide the real diagnostic.
        assert param.annotation != missing, \
            "No signature on param {name} for function {func}".format(name=param.name, func=f.__name__)
    assert sig.return_annotation != missing, \
        "No return annotation for function {func}".format(func=f.__name__)
def signatures_match(decomposition_sig, torch_op_sig):
    """Tell whether a decomposition signature lines up with a torch op signature.

    Parameters are compared pairwise, in order. Two signatures match when they
    have the same number of parameters, every pair agrees on name and
    annotation, defaults are equal wherever *both* sides declare one, and the
    return annotations are equal.

    A warning is emitted for each decomposition parameter named ``self`` that
    is visited before a mismatch is found.

    Args:
        decomposition_sig: ``inspect.Signature`` of the decomposition function.
        torch_op_sig: ``inspect.Signature`` of one overload of the torch op.

    Returns:
        ``False`` on the first disagreement, otherwise the result of comparing
        the two return annotations.
    """
    ours = decomposition_sig.parameters
    theirs = torch_op_sig.parameters
    if len(ours) != len(theirs):
        return False

    unset = inspect.Parameter.empty
    for ours_param, theirs_param in zip(ours.values(), theirs.values()):
        if ours_param.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
        if ours_param.name != theirs_param.name:
            return False
        if ours_param.annotation != theirs_param.annotation:
            return False

        # A default present on only one side is tolerated; the values are
        # compared only when both sides supply one.
        ours_default = ours_param.default
        theirs_default = theirs_param.default
        if ours_default != unset and theirs_default != unset and ours_default != theirs_default:
            return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
def register_decomposition(aten_op, registry=None):
    """Build a decorator that registers a function as the decomposition of ``aten_op``.

    The decorator validates the function, scripts it with ``torch.jit.script``,
    runs inlining plus two rounds of peephole / constant-propagation passes on
    the scripted graph, and stores the scripted function in the registry under
    the string form of the matching operator schema.

    Args:
        aten_op: The operator whose overload signatures are looked up through
            ``get_signature_for_torch_op``.
        registry: Mapping that receives the scripted function. When ``None``,
            the module-level ``decomposition_table`` is used.

    Returns:
        A decorator that returns the decorated function itself, unmodified.

    Raises:
        AssertionError: From the decorator, if the function is missing type
            annotations, if none of the operator's signatures matches the
            function's signature, or if a function with the same ``__name__``
            has already been registered.
    """
    def decomposition_decorator(f):
        nonlocal registry
        if registry is None:
            registry = decomposition_table

        check_decomposition_has_type_annotations(f)

        op_signatures, op_schemas = get_signature_for_torch_op(aten_op, return_schemas=True)
        own_signature = inspect.signature(f)
        # Index of the first overload whose signature agrees with ``f``.
        match_index = next(
            (idx for idx, candidate in enumerate(op_signatures)
             if signatures_match(own_signature, candidate)),
            None,
        )
        assert match_index is not None, "Could not find matching signature: " + str(f)

        fn_name = f.__name__
        assert fn_name not in function_name_set, "Duplicated function name {}".format(fn_name)
        function_name_set.add(fn_name)

        scripted = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted.graph)

        # Peephole followed by constant propagation, applied twice in a row.
        cleanup_passes = (
            torch._C._jit_pass_peephole,
            torch._C._jit_pass_constant_propagation,
        )
        for _ in range(2):
            for run_pass in cleanup_passes:
                run_pass(scripted.graph)

        registry[str(op_schemas[match_index])] = scripted
        return f

    return decomposition_decorator
@register_decomposition(aten.var)
def var_decomposition(input: Tensor, dim: Optional[List[int]] = None, correction: Optional[int] = None,
                      keepdim: bool = False) -> Tensor:
    """Compute variance from mean, subtraction, multiplication, sum and division.

    Args:
        input: Tensor to reduce.
        dim: Dimensions to reduce over. ``None`` or an empty list reduces over
            every element of ``input``.
        correction: Amount subtracted from the element count to form the
            divisor. ``None`` is treated as 1, matching the documented default
            of ``torch.var``.
        keepdim: Forwarded to the final ``aten.sum`` reduction.

    Returns:
        The sum of squared deviations from the mean divided by
        ``max(0, n - correction)``, where ``n`` is the number of elements
        reduced over.
    """
    if dim is None:
        all_dims: List[int] = []
        dim = all_dims

    # Number of elements that take part in the reduction.
    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        n = input.numel()
    else:
        n = 1
        for d in dim:
            n *= input.shape[d]

    # The mean keeps the reduced dims so it broadcasts against ``input``.
    mean = aten.mean(input, dim, True)
    centered = input - mean
    squared = centered * centered
    total = aten.sum(squared, dim, keepdim)

    # An omitted correction means 1, not "no correction".
    if correction is None:
        denom = n - 1
    else:
        denom = n - correction

    # Clamp at zero so an over-large correction cannot flip the sign.
    return total / max(0, denom)
@register_decomposition(aten.var)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Variance over all elements, expressed through ``var_decomposition``.

    Args:
        input: Tensor to reduce.
        unbiased: When true a correction of 1 is passed on, otherwise 0.

    Returns:
        Whatever ``var_decomposition`` returns for ``input`` with that
        correction and its remaining arguments left at their defaults.
    """
    correction = 1 if unbiased else 0
    return var_decomposition(input, correction=correction)