UMMJ's picture
Upload 5875 files
9dd3461
import torch
from torch._prims_common import (
Number,
NumberType,
TensorLike,
TensorLikeType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten
from typing import Callable, Sequence, Union, Tuple, NamedTuple
import inspect
from functools import wraps, reduce
import operator
import warnings
from itertools import chain
# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    """
    Convert ``a`` to ``dtype``.

    * Tensors are converted with ``prims.convert_element_type`` when their
      dtype differs; otherwise the same tensor object is returned.
    * Python numbers are converted with the Python type constructor that
      ``utils.dtype_to_type_ctor(dtype)`` returns.
    * Sequences (other than ``str``) are converted element-wise and returned
      as a tuple.
    * ``None`` is passed through unchanged.

    Raises:
        ValueError: for any other input type, including ``str``.
    """
    import torch._prims as prims

    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a

    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)

    # A str is a Sequence whose elements are again str, so recursing into it
    # would never terminate; let it fall through to the ValueError instead.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)

    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    """
    Convert the Python scalar ``a`` to the Python type ``typ``.

    The conversion only happens when ``utils.is_weakly_lesser_type`` reports
    that ``type(a)`` is weakly lesser than ``typ``, i.e. the cast is safe.

    Raises:
        ValueError: if ``a`` is not a Number, or if the cast is not safe.
    """
    source_type = type(a)

    if not isinstance(a, Number):
        raise ValueError(
            "Found unknown type {0} when trying to convert scalars!".format(source_type)
        )

    is_safe = utils.is_weakly_lesser_type(source_type, typ)
    if is_safe:
        return typ(a)

    raise ValueError(
        "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
            a, source_type, typ
        )
    )
def _annotation_has_type(*, typ, annotation):
    """
    Return True if ``typ`` is ``annotation`` itself or appears anywhere among
    its (recursively nested) ``__args__``, e.g. inside ``Union``/``Optional``.
    """
    # Leaf annotations (plain classes, None, ...) have no __args__.
    if not hasattr(annotation, "__args__"):
        return typ is annotation
    return any(
        _annotation_has_type(typ=typ, annotation=arg) for arg in annotation.__args__
    )
class elementwise_type_promotion_wrapper(object):
    """
    Decorator that adds elementwise type promotion to a Python reference
    implementation.

    Construct it with two keyword arguments:

    * ``type_promoting_args`` -- a Sequence of argument names. The values of
      those arguments (when supplied) participate in type promotion and are
      converted to the computation dtype. If a value is itself a Sequence,
      every element of it participates.
    * ``type_promotion_kind`` -- one of the kinds in
      ELEMENTWISE_TYPE_PROMOTION_KIND; see its documentation for details.

    On each call the wrapper:

    1. computes ``(compute_dtype, result_dtype)`` with ``utils.elementwise_dtypes``;
    2. converts the promoting arguments to ``compute_dtype``;
    3. calls the wrapped function with every bound argument passed by keyword;
    4. converts the result to ``result_dtype``. The result may be a tensor or a
       Sequence, and a Sequence comes back as a tuple. Any other result type
       raises AssertionError.

    Other type promotion behavior, like validating the Python type of scalar
    arguments, must be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Sequence[str] = None,
    ):
        # NOTE(review): despite the None default, this attribute is iterated on
        # every call of the decorated function, so callers must pass a Sequence.
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        signature = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            arguments = bound.arguments

            # Only promoting args that were actually supplied take part.
            present = [
                name
                for name in self.type_promoting_arg_names  # type: ignore[union-attr]
                if name in arguments
            ]

            leaves, _ = tree_flatten(tuple(arguments[name] for name in present))
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *leaves, type_promotion_kind=self.type_promotion_kind
            )

            # Convert every promoting arg from its original value first, then
            # write them all back in one update.
            converted = {
                name: _maybe_convert_to_dtype(arguments[name], compute_dtype)
                for name in present
            }
            arguments.update(converted)

            result = fn(**arguments)

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if not isinstance(result, Sequence):
                raise AssertionError(f"Unhandled result type: {type(result)}")
            return tuple(_maybe_convert_to_dtype(item, result_dtype) for item in result)

        _fn.__signature__ = signature  # type: ignore[attr-defined]
        return _fn
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape):
    """
    Resize ``out`` in place to ``shape`` when its current shape differs.

    This mirrors eager ``out=`` semantics: an ``out`` whose shape already
    matches is left alone, an empty ``out`` is resized silently, and a
    non-empty ``out`` is resized with a deprecation warning. This includes
    the case where the element count matches but the shape does not, which
    the previous numel-only check silently skipped.

    Returns:
        ``out`` (the tensor ``resize_`` returns when a resize happens).
    """
    # Normalize so torch.Size, tuple and list shapes compare equal.
    required = tuple(shape) if isinstance(shape, Sequence) else (shape,)
    if tuple(out.shape) == required:
        return out

    if out.numel() != 0:
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format(
                str(out.shape), str(shape)
            )
        )
        warnings.warn(msg)

    return out.resize_(shape)
def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    """
    Copy ``copy_from`` into ``copy_to`` in place after validating the copy.

    Args:
        copy_from: source tensor.
        copy_to: destination tensor; written in place with ``copy_``.
        exact_dtype: if True, the dtypes must match exactly. Otherwise
            ``utils.can_safe_cast_to`` must allow casting the source dtype
            to the destination dtype.

    Returns:
        The result of ``copy_to.copy_(copy_from)``.

    Raises:
        RuntimeError: if the tensors are on different devices, or if the
            dtype check fails.
    """
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        if copy_from.dtype != copy_to.dtype:
            # Both halves must be f-strings; previously the second half was a
            # plain literal and "{copy_to.dtype}" was printed verbatim.
            raise RuntimeError(
                f"Expected out tensor to have dtype {copy_from.dtype} "
                f"but got {copy_to.dtype} instead"
            )
    elif not utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype):
        raise RuntimeError(
            f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!"
        )

    return copy_to.copy_(copy_from)
def out_wrapper(*out_names: str, exact_dtype: bool = False):
    """
    Return a decorator that adds a keyword-only ``out`` parameter to a Python
    reference.

    Args:
        *out_names: field names for references that return a tuple of
            tensors. Leave empty for a reference that returns a single tensor.
            When given, at least two names are required.
        exact_dtype: if True, every ``out`` tensor must have exactly the dtype
            of the corresponding result. Otherwise the result only has to be
            safely castable to the ``out`` dtype.

    The decorated function returns the computed tensor (single-output case) or
    a NamedTuple of the computed tensors. When ``out`` is supplied, each result
    is resized into and copied into the matching ``out`` tensor, and the
    ``out`` tensor(s) are returned instead.
    """
    is_tensor = len(out_names) == 0
    assert is_tensor or len(out_names) >= 2

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        # A reference taking both device and dtype is treated as a factory:
        # when out= is given, any of these kwargs the caller omitted default
        # to the corresponding attribute of out.
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr

            result = fn(*args, **kwargs)
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    utils.check(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tuple
        assert sig.return_annotation in (sig.empty, out_type)
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        # Copy rather than alias: functools.wraps makes _fn share fn's
        # __annotations__ dict, so writing "out"/"return" into it would
        # otherwise mutate the wrapped reference's own annotations.
        _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type
        return _fn

    return _out_wrapper
def backwards_not_supported(prim):
    """
    Wrap ``prim`` so it can be called on inputs that require grad, while any
    backward pass through it raises.

    When grad mode is enabled and at least one tensor among the flattened
    (``tree_flatten``) args/kwargs requires grad, the call goes through a
    custom ``torch.autograd.Function`` whose backward raises RuntimeError.
    Otherwise ``prim`` is called directly. In both paths ``prim`` runs while a
    ``torch._C._AutoDispatchBelowAutograd()`` object is alive, which presumably
    keeps dispatch below the autograd layer; confirm against the torch version
    in use.
    """

    def redispatch_prim(args, kwargs):
        below_autograd = torch._C._AutoDispatchBelowAutograd()
        try:
            return prim(*args, **kwargs)
        finally:
            # Drop our reference to the guard as soon as the call finishes.
            del below_autograd

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        wants_autograd = torch.is_grad_enabled() and any(
            isinstance(leaf, torch.Tensor) and leaf.requires_grad for leaf in flat_args
        )
        if not wants_autograd:
            return redispatch_prim(args, kwargs)

        # TODO: There is a subtle bug here: prims like copy_to
        # return their input argument after mutating it; and custom
        # autograd function will incorrectly turn the result into
        # a view which will fail test_python_ref_executor tests.
        # At the moment, we sidestep this by observing that the
        # unit tests don't ever try to run the executor with
        # autograd, so we don't exercise the buggy case, but if
        # you ever want to feed autograd through this, be aware
        # of it! We need a way of properly implementing autograd
        # for mutating operations in Python to do this.
        return BackwardsNotSupported.apply(args_spec, *flat_args)

    return _autograd_impl
# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Let a unary operator that accepts tensors also accept a Python number.

    If the first positional argument is a Number, it is wrapped with
    ``torch.tensor(..., dtype=utils.type_to_dtype(type(number)))``. ``fn`` is
    then called and its tensor result is converted back to a Python scalar
    with ``.item()``. Any other call is forwarded to ``fn`` unchanged. The
    wrapper reports ``fn``'s signature.
    """
    fn_signature = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if not args or not isinstance(args[0], Number):
            return fn(*args, **kwargs)

        scalar, *remaining = args
        scalar_dtype = utils.type_to_dtype(type(scalar))
        as_tensor = torch.tensor(scalar, dtype=scalar_dtype)
        result = fn(as_tensor, *remaining, **kwargs)
        assert isinstance(result, torch.Tensor)
        return result.item()

    _fn.__signature__ = fn_signature  # type: ignore[attr-defined]
    return _fn