import torch

from torch._prims_common import (
    Number,
    NumberType,
    TensorLike,
    TensorLikeType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten

from typing import Callable, Sequence, Union, Tuple, NamedTuple
import inspect
from functools import wraps, reduce
import operator
import warnings
from itertools import chain


def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    import torch._prims as prims

    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # Only converts when the tensor's dtype actually differs
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Passes None through, since some wrapped functions have optional arguments
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor nor a number!".format(type(a))
    )
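# Example (informal sketch, not part of the module's API): the helper accepts
# tensors, Python numbers, sequences, and None. Assuming the usual mapping from
# dtypes to Python types in torch._prims_common:
#
#   _maybe_convert_to_dtype(torch.tensor([1, 2]), torch.float32)  # float32 tensor
#   _maybe_convert_to_dtype(3, torch.float32)                     # 3.0
#   _maybe_convert_to_dtype((1, 2), torch.complex64)              # ((1+0j), (2+0j))
#   _maybe_convert_to_dtype(None, torch.float32)                  # None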


def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = "Found unknown type {0} when trying to convert scalars!".format(type(a))
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
            a, type(a), typ
        )
        raise ValueError(msg)

    return typ(a)
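# Example (informal sketch), assuming utils.is_weakly_lesser_type follows the
# bool < int < float < complex ordering:
#
#   _maybe_convert_to_type(2, float)   # 2.0
#   _maybe_convert_to_type(2.5, int)   # raises ValueError: not a safe cast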


def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True
        return False

    return typ is annotation


class elementwise_type_promotion_wrapper(object):
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a Sequence of strings specifying the names of all
    arguments that participate in type promotion (and should be type promoted). If an
    argument is itself a Sequence then every element of that Sequence participates in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Sequence[str] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig
        return _fn
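# Example (informal sketch): decorating a hypothetical reference implementation.
# `_add_ref` is an illustrative name, not an existing function in this module.
#
#   @elementwise_type_promotion_wrapper(
#       type_promoting_args=("a", "b"),
#       type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   )
#   def _add_ref(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
#       return torch.add(a, b)
#
# With the wrapper applied, the inputs are converted to the computation dtype
# chosen by utils.elementwise_dtypes before calling _add_ref, and the result is
# converted to the corresponding result dtype.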


def _maybe_resize_out(out: TensorLikeType, shape):
    if out.numel() == 0:
        return out.resize_(shape)

    if out.numel() != reduce(operator.mul, shape, 1):
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format(
                str(out.shape), str(shape)
            )
        )
        warnings.warn(msg)
        return out.resize_(shape)

    return out
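# Example (informal sketch) of the resizing behavior:
#
#   out = torch.empty(0)
#   _maybe_resize_out(out, (2, 3))  # silently resized, since out had zero elements
#
#   out = torch.empty(4)
#   _maybe_resize_out(out, (2, 3))  # resized, but emits the deprecation warning above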


def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks that the cast is safe
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)
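# Example (informal sketch): _safe_copy_out refuses unsafe and cross-device copies.
#
#   a = torch.ones(3, dtype=torch.float32)
#   out = torch.empty(3, dtype=torch.float64)
#   _safe_copy_out(copy_from=a, copy_to=out)                    # ok: float32 -> float64 is safe
#   _safe_copy_out(copy_from=a, copy_to=out, exact_dtype=True)  # raises: dtypes must match exactly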


def out_wrapper(*out_names: str, exact_dtype: bool = False):
    is_tensor = len(out_names) == 0
    assert is_tensor or len(out_names) >= 2

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr

            result = fn(*args, **kwargs)
            assert (isinstance(result, TensorLike) and is_tensor) or (
                isinstance(result, Tuple) and len(result) == len(out_names)
            )
            if out is not None:
                # NOTE: the exact Python type of out and result may differ (for
                # example when tensor subclasses are involved), so no type
                # equality is asserted here; only shape and dtype compatibility
                # are validated below.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # Both of these operations are done in-place on out
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)
                else:
                    assert isinstance(out, Tuple)
                    utils.check(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    for r, o in zip(result, out):
                        # Both of these operations are done in-place on o
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)
            else:
                out = result

            return out if is_tensor else return_type(*out)

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )

        assert sig.return_annotation in (sig.empty, out_type)
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(
            parameters=params, return_annotation=return_type
        )
        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type
        return _fn

    return _out_wrapper
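# Example (informal sketch): adding an `out=` kwarg to a hypothetical reference.
# `_sin_ref` is an illustrative name only.
#
#   @out_wrapper()
#   def _sin_ref(a: TensorLikeType) -> TensorLikeType:
#       return torch.sin(a)
#
#   buf = torch.empty(3)
#   _sin_ref(torch.randn(3), out=buf)  # the result is resized/copied into buf and returned
#
# With explicit out names, e.g. @out_wrapper("values", "indices"), the wrapped
# function must return a tuple of that length and the result is returned as a
# NamedTuple named return_types_<fn.__name__>.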


def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        # Dispatches below autograd while the guard object is alive
        g = torch._C._AutoDispatchBelowAutograd()
        try:
            return prim(*args, **kwargs)
        finally:
            del g

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # Routes through the autograd Function so that attempting to
            # differentiate this prim raises a RuntimeError
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl
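# Example (informal sketch): wrapping a hypothetical prim implementation.
# `_my_prim_impl` is an illustrative name only.
#
#   def _my_prim_impl(a: torch.Tensor) -> torch.Tensor:
#       return a * 2
#
#   my_prim = backwards_not_supported(_my_prim_impl)
#   my_prim(torch.randn(3))                                        # fine
#   my_prim(torch.randn(3, requires_grad=True)).sum().backward()   # raises RuntimeError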


def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig
    return _fn
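# Example (informal sketch): with this wrapper, a hypothetical `_abs_ref` that
# normally takes a tensor also accepts a plain Python scalar.
#
#   @elementwise_unary_scalar_wrapper
#   def _abs_ref(a: TensorLikeType) -> TensorLikeType:
#       return torch.abs(a)
#
#   _abs_ref(torch.tensor([-1.0, 2.0]))  # tensor([1., 2.])
#   _abs_ref(-3.5)                       # 3.5 (the scalar is boxed, computed, then .item()'d)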