File size: 4,056 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from __future__ import annotations
import contextlib
import functools
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
import torchgen.local as local
from torchgen.model import (
BackendIndex,
DispatchKey,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
from torchgen.utils import context, S, T
if TYPE_CHECKING:
from collections.abc import Iterator
# Helper functions for defining generators on things in the model

# Type of the primary argument accepted by the decorators in this module; it is
# the value handed to native_function_manager.  Using a TypeVar (rather than a
# plain Union) lets a decorated function keep its precise input type, since the
# decorators are typed Callable[[F], T] -> Callable[[F], T].
# NOTE: TypeVar constraints are evaluated at runtime (the __future__ annotations
# import does not defer them), presumably why typing.Union is spelled out here
# instead of `X | Y` between classes.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# Type of the second, pass-through argument of with_native_function_and.  It does
# not participate in establishing the native-function context.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)

# Indexable argument of method_with_nested_native_function: element [0] (a
# NativeFunction in both constraints) is used to enter native_function_manager.
# Subscripting the builtins tuple/list at runtime requires Python 3.9+.
F3 = TypeVar("F3", tuple[NativeFunction, Any], list[NativeFunction])
@contextlib.contextmanager
def native_function_manager(
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
    """Enter the error context and codegen-local settings for one native function.

    ``g`` is first reduced to a single representative NativeFunction:

    * NativeFunctionsGroup -> its ``out`` variant.  By default, errors for
      structured native functions are attributed to the out variant; if a more
      specific location is wanted, nest another ``native_function_manager``
      call inside.
    * NativeFunctionsViewGroup -> its ``view`` operator.
    * anything else -> ``g`` itself.

    While active, ``context`` is given a description built from that function's
    native_functions.yaml location and schema.  ``torchgen.local`` is
    parametrized from the function's ``use_const_ref_for_mutable_tensors`` and
    ``part_of_structured_group`` attributes.
    """
    if isinstance(g, NativeFunctionsGroup):
        f = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        f = g.view
    else:
        f = g

    def describe() -> str:
        # Built lazily: only evaluated if the context actually needs it.
        return f"in native_functions.yaml line {f.loc}:\n {f.func}"

    # Entered left to right, exactly like two nested `with` statements.
    with context(describe), local.parametrize(
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=f.part_of_structured_group,
    ):
        yield
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate ``func`` so every call runs inside ``native_function_manager``.

    The wrapper hands its single argument ``f`` to ``native_function_manager``
    and only then calls ``func(f)``, so the appropriate context managers for
    that native function are active for the whole call.

    YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound: those
    modules read the local settings established here and raise an error if
    they are accessed without having been set.
    """

    def call_in_context(f: F) -> T:
        with native_function_manager(f):
            return func(f)

    return functools.wraps(func)(call_in_context)
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Two-argument variant of ``with_native_function``.

    Only the first argument ``f`` is used to enter ``native_function_manager``;
    the second argument ``f2`` is forwarded to ``func`` unchanged.
    """

    def call_in_context(f: F, f2: F2) -> T:
        # The first argument is assumed to carry the appropriate context.
        with native_function_manager(f):
            return func(f, f2)

    functools.update_wrapper(call_in_context, func)
    return call_in_context
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of ``with_native_function``.

    ``func`` takes ``(slf, f)``.  The wrapper enters
    ``native_function_manager(f)`` and then delegates; ``slf`` plays no part in
    the context and is passed straight through.
    """

    @functools.wraps(func)
    def method_in_context(slf: S, f: F) -> T:
        with native_function_manager(f):
            return func(slf, f)

    return method_in_context
def method_with_nested_native_function(
    func: Callable[[S, F3], T],
) -> Callable[[S, F3], T]:
    """Method decorator for functions whose argument bundles native functions.

    ``f`` is indexable (a tuple or list per ``F3``), and its first element
    selects the native-function context.  The whole ``f`` is still what gets
    passed on to ``func``.
    """

    @functools.wraps(func)
    def wrapper(slf: S, f: F3) -> T:
        representative = f[0]
        with native_function_manager(representative):
            return func(slf, f)

    return wrapper
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T],
) -> Callable[[F, BackendIndex], T]:
    """Convenience decorator for ``(f, backend_index)`` functions.

    Meant for functions that take their BackendIndex as an explicit argument
    instead of capturing one in a closure.  ``backend_index`` is forwarded
    unchanged; only ``f`` is used to enter ``native_function_manager``.
    """

    def run_with_index(f: F, backend_index: BackendIndex) -> T:
        manager = native_function_manager(f)
        with manager:
            return func(f, backend_index)

    return functools.update_wrapper(run_with_index, func)
def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Convenience decorator for ``(f, backend_indices)`` functions.

    ``backend_indices`` is a dict mapping DispatchKey to BackendIndex and is
    forwarded to ``func`` untouched.  The native-function context is entered
    from ``f`` alone.
    """

    @functools.wraps(func)
    def run_with_indices(
        f: F, backend_indices: dict[DispatchKey, BackendIndex]
    ) -> T:
        with native_function_manager(f):
            return func(f, backend_indices)

    return run_with_indices
|