|
from __future__ import annotations |
|
|
|
from typing import Any, Union, Sequence, Optional, Tuple, List, Callable, Type, overload, cast |
|
from enum import Enum |
|
from functools import reduce, cmp_to_key |
|
import operator |
|
import weakref |
|
|
|
import torch |
|
|
|
|
|
if hasattr(torch._C, "_nvfuser"): |
|
from torch._C._nvfuser import DataType |
|
|
|
_torch_dtype_to_nvfuser_dtype_map = { |
|
torch.cdouble: DataType.ComplexDouble, |
|
torch.cfloat: DataType.ComplexFloat, |
|
torch.double: DataType.Double, |
|
torch.float: DataType.Float, |
|
torch.half: DataType.Half, |
|
torch.bfloat16: DataType.BFloat16, |
|
torch.long: DataType.Int, |
|
torch.int: DataType.Int32, |
|
torch.bool: DataType.Bool, |
|
|
|
complex: DataType.ComplexDouble, |
|
float: DataType.Double, |
|
int: DataType.Int, |
|
bool: DataType.Bool, |
|
} |
|
else: |
|
_torch_dtype_to_nvfuser_dtype_map = {} |
|
|
|
|
|
def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]): |
|
""" |
|
Translates from torch.dtype to nvFuser's DataType enum |
|
""" |
|
return _torch_dtype_to_nvfuser_dtype_map[dtype] |
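# For example (illustrative): with nvFuser available, getnvFuserDtype(torch.float)
# returns DataType.Float and getnvFuserDtype(float) returns DataType.Double;
# unmapped dtypes (or any lookup when nvFuser is unavailable) raise KeyError.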
|
|
|
|
|
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]] |
|
StrideType = Union[List[int], Tuple[int, ...]] |
|
DimsType = Union[int, List[int], Tuple[int, ...]] |
|
DimsSequenceType = Union[List[int], Tuple[int, ...]] |
|
|
|
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] |
|
|
|
|
|
NumberType = Union[bool, int, float, complex] |
|
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode) |
|
DeviceLikeType = Union[str, torch.device] |
|
Tensor = torch.Tensor |
|
|
|
|
|
torch_function_passthrough = { |
|
torch.Tensor.dim, |
|
torch.Tensor.ndim.__get__, |
|
torch.Tensor.numel, |
|
torch.Tensor.size, |
|
torch.Tensor.storage_offset, |
|
torch.Tensor.stride, |
|
torch.Tensor.dtype.__get__, |
|
torch.Tensor.is_sparse.__get__, |
|
torch.Tensor.shape.__get__, |
|
torch.Tensor.device.__get__, |
|
torch.Tensor.requires_grad.__get__, |
|
torch.Tensor.layout.__get__, |
|
|
|
torch.Tensor.__format__, |
|
torch.Tensor.__repr__, |
|
|
} |
|
|
|
|
|
TensorLikeType = torch.Tensor |
|
TensorLike = torch.Tensor |
|
TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] |
|
TensorOrNumberLikeType = Union[TensorLikeType, NumberType] |
|
|
|
|
|
def same_shape(a: ShapeType, b: ShapeType) -> bool: |
|
if len(a) != len(b): |
|
return False |
|
|
|
for x, y in zip(a, b): |
|
if x != y: |
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=False): |
|
""" |
|
    Checks that two tensor-likes have the same shape,
    dtype, and device.

    If check_strides=True, significant strides are compared as well.
|
""" |
|
assert isinstance(a, TensorLike) |
|
assert isinstance(b, TensorLike) |
|
|
|
if not same_shape(a.shape, b.shape): |
|
msg = "Shapes {0} and {1} are not equal!".format(a.shape, b.shape) |
|
raise AssertionError(msg) |
|
|
|
if a.dtype != b.dtype: |
|
msg = "Dtypes {0} and {1} are not equal!".format(a.dtype, b.dtype) |
|
raise AssertionError(msg) |
|
|
|
if a.device != b.device: |
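        # Special case: a bare "cuda" device and "cuda:0" are considered
        # the same device for this comparison.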
|
|
|
|
|
if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( |
|
str(b.device) == "cuda:0" or str(b.device) == "cuda" |
|
): |
|
pass |
|
else: |
|
msg = "Devices {0} and {1} are not equal!".format(a.device, b.device) |
|
raise AssertionError(msg) |
|
|
|
|
|
if check_strides: |
|
same_strides, idx = check_significant_strides(a, b) |
|
if not same_strides: |
|
msg = ( |
|
"Stride mismatch! Strides are {0} and {1} (mismatched at {2})!".format( |
|
a.stride(), b.stride(), idx |
|
) |
|
) |
|
raise RuntimeError(msg) |
|
|
|
|
|
def check_significant_strides( |
|
a: TensorLikeType, b: TensorLikeType |
|
) -> Tuple[bool, Optional[int]]: |
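    # Returns (False, idx) when tensors a and b differ in a "significant" stride,
    # i.e. a stride on a dimension with more than one element. Note the check
    # below only runs for non-empty CUDA tensors.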
|
|
|
|
|
|
|
|
|
if (a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0: |
|
for idx in range(a.ndim): |
|
if a.stride()[idx] != b.stride()[idx] and a.shape[idx] > 1: |
|
return False, idx |
|
|
|
return True, None |
|
|
|
|
|
|
|
def is_contiguous(a: TensorLikeType) -> bool: |
|
""" |
|
Tests whether a tensor is contiguous or not. |
|
|
|
Tensors are contiguous when they have no elements, |
|
one element, or when they have "nested" strides. |
|
""" |
|
if a.numel() < 2: |
|
return True |
|
|
|
expected_stride = 1 |
|
for x, y in reversed(tuple(zip(a.shape, a.stride()))): |
|
|
|
if x == 1: |
|
continue |
|
|
|
if y != expected_stride: |
|
return False |
|
expected_stride = expected_stride * x |
|
|
|
return True |
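# For example (illustrative): a shape-(2, 3, 4) tensor with strides (12, 4, 1)
# is contiguous, while the same shape with strides (1, 2, 6) (a column-major
# layout) is not.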
|
|
|
|
|
|
|
def is_channels_last_contiguous_2d(a: Tensor) -> bool: |
|
|
|
if a.ndim != 4: |
|
return False |
|
|
|
expected_stride = 1 |
|
for idx in (1, 3, 2, 0): |
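        # NHWC iteration order: C (1), W (3), H (2), N (0) -- from the
        # fastest-varying dimension outward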
|
|
|
length = a.shape[idx] |
|
if length == 1: |
|
continue |
|
|
|
stride = a.stride()[idx] |
|
if stride != expected_stride: |
|
return False |
|
|
|
expected_stride *= length |
|
|
|
return True |
|
|
|
|
|
def is_channels_last_contiguous_3d(a: Tensor) -> bool: |
|
|
|
if a.ndim != 5: |
|
return False |
|
|
|
expected_stride = 1 |
|
for idx in (1, 4, 3, 2, 0): |
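        # NDHWC iteration order: C (1), W (4), H (3), D (2), N (0) -- from the
        # fastest-varying dimension outward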
|
|
|
length = a.shape[idx] |
|
if length == 1: |
|
continue |
|
|
|
stride = a.stride()[idx] |
|
if stride != expected_stride: |
|
return False |
|
|
|
expected_stride *= length |
|
|
|
return True |
|
|
|
|
|
_memory_formats = {
    torch.contiguous_format,
    torch.preserve_format,
    torch.channels_last,
    torch.channels_last_3d,
}
|
|
|
|
|
def validate_memory_format(memory_format: torch.memory_format): |
|
check( |
|
memory_format in _memory_formats, |
|
lambda: f"Received unknown memory format {memory_format}!", |
|
) |
|
|
|
|
|
def is_contiguous_for_memory_format( |
|
a: Tensor, *, memory_format: torch.memory_format |
|
) -> bool: |
|
validate_memory_format(memory_format) |
|
|
|
if memory_format == torch.contiguous_format: |
|
return is_contiguous(a) |
|
if memory_format == torch.channels_last: |
|
return is_channels_last_contiguous_2d(a) |
|
if memory_format == torch.channels_last_3d: |
|
return is_channels_last_contiguous_3d(a) |
|
|
|
check( |
|
False, |
|
lambda: f"is_contiguous received unsupported memory format {memory_format}", |
|
) |
|
|
|
|
|
|
|
def is_channels_last_contiguous(a: Tensor) -> bool: |
|
""" |
|
True when a tensor is channels-last contiguous. |
|
|
|
This requires that: |
|
|
|
    - the tensor has either 4 (NHWC) or 5 (NDHWC) dimensions
|
- if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the |
|
stride of the 'C' dimension (Cs) is 1 and the strides corresponding to |
|
each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are |
|
"nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, |
|
for example. |
|
""" |
|
return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) |
|
|
|
|
|
def is_non_overlapping_and_dense(a: Tensor) -> bool: |
|
""" |
|
True when a tensor is non-overlapping and dense. |
|
|
|
A tensor is non-overlapping and dense when there exists a permutation of |
|
its dimensions that is contiguous. |
|
""" |
|
|
|
|
|
if is_contiguous(a) or is_channels_last_contiguous(a): |
|
return True |
|
|
|
|
|
|
|
|
|
|
|
if a.ndim == 1: |
|
return a.stride()[0] == 1 |
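    # For higher-rank tensors, check whether some permutation of the dimensions,
    # ordered by increasing stride, has the "nested" strides of a contiguous tensor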
|
|
|
|
|
|
|
lengths_and_strides = sorted( |
|
tuple(zip(a.shape, a.stride())), key=operator.itemgetter(1) |
|
) |
|
|
|
expected_stride = 1 |
|
for length, stride in lengths_and_strides: |
|
|
|
if length == 1: |
|
continue |
|
|
|
if stride != expected_stride: |
|
return False |
|
|
|
expected_stride *= length |
|
|
|
    return True
|
|
|
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: |
|
""" |
|
Computes the output strides for elementwise operations. |
|
""" |
|
|
|
if len(tensors) == 0: |
|
msg = "Can't compute elementwise output strides for zero tensors!" |
|
raise ValueError(msg) |
|
|
|
check_same_shape(*tensors, allow_cpu_scalar_tensors=True) |
|
|
|
|
|
tensors = tuple( |
|
a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) |
|
) |
|
|
|
|
|
if len(tensors) == 0: |
|
return () |
|
|
|
|
|
|
|
ndim = tensors[0].ndim |
|
if ndim == 0: |
|
return () |
|
if ndim == 1: |
|
return (1,) |
|
|
|
shape = tensors[0].shape |
|
|
|
def _cmp(idx_a, idx_b): |
|
for tensor in tensors: |
|
stride_a = tensor.stride()[idx_a] |
|
stride_b = tensor.stride()[idx_b] |
|
|
|
if stride_a == 0 or stride_b == 0: |
|
continue |
|
|
|
if stride_a < stride_b: |
|
return -1 |
|
|
|
if stride_a > stride_b: |
|
return 1 |
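            # strides are equal (and nonzero): break the tie by dimension
            # length so longer dimensions sort as more "outer"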
|
|
|
|
|
if shape[idx_a] > shape[idx_b]: |
|
return 1 |
|
|
|
|
|
if shape[idx_a] < shape[idx_b]: |
|
return -1 |
|
|
|
|
|
|
|
return 0 |
|
|
|
perm = tuple(range(ndim)) |
|
perm = tuple(sorted(perm, key=cmp_to_key(_cmp), reverse=True)) |
|
|
|
permuted_shape = [-1] * ndim |
|
for idx, x in enumerate(perm): |
|
permuted_shape[idx] = shape[x] |
|
|
|
new_strides = make_contiguous_strides_for(permuted_shape) |
|
permuted_strides = [-1] * ndim |
|
for idx, x in enumerate(perm): |
|
permuted_strides[x] = new_strides[idx] |
|
|
|
return tuple(permuted_strides) |
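# For example (illustrative): given a channels-last NCHW tensor of shape
# (2, 3, 4, 5) (strides (60, 1, 15, 3)) and a contiguous tensor of the same
# shape, the result is (60, 1, 15, 3) -- each dimension comparison is resolved
# by the first tensor whose strides differ, so the channels-last layout wins.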
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_dim_length(length: int): |
|
""" |
|
Validates that an object represents a valid |
|
dimension length. |
|
""" |
|
|
|
assert length >= 0 |
|
|
|
|
|
def validate_shape(shape: ShapeType): |
|
""" |
|
Validates that a sequence represents a valid shape. |
|
""" |
|
|
|
assert isinstance(shape, Sequence) |
|
for l in shape: |
|
validate_dim_length(l) |
|
|
|
|
|
def validate_strides(strides: StrideType): |
|
""" |
|
Verifies the object specifies valid strides. |
|
""" |
|
|
|
assert isinstance(strides, Sequence) |
|
for stride in strides: |
|
assert stride >= 0 |
|
|
|
|
|
def validate_idx(rank: int, idx: int): |
|
""" |
|
    Validates that idx is a valid index for the given rank.
|
Assumes the index is already canonicalized. |
|
""" |
|
|
|
assert isinstance(idx, int) |
|
assert isinstance(rank, int) |
|
|
|
    # note: idx == 0 is always permitted, which allows indexing into
    # zero-dim (scalar) shapes
    assert (idx >= 0 and idx < rank) or idx == 0
|
|
|
|
|
def validate_dimension_indices(rank: int, indices: DimsSequenceType): |
|
for idx in indices: |
|
validate_idx(rank, idx) |
|
|
|
|
|
def validate_exclusive_idx(rank: int, ex_idx: int): |
|
""" |
|
Validates that ex_idx is a valid exclusive index |
|
    for the given rank.
|
""" |
|
|
|
assert isinstance(ex_idx, int) |
|
assert isinstance(rank, int) |
|
assert ex_idx > 0 and ex_idx <= rank |
|
|
|
|
|
|
|
|
|
|
|
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: |
|
if rank < 0: |
|
msg = f"Rank cannot be negative but got {rank}" |
|
raise IndexError(msg) |
|
|
|
if rank == 0: |
|
if not wrap_scalar: |
|
msg = f"Dimension specified as {idx} but tensor has no dimensions" |
|
raise IndexError(msg) |
|
rank = 1 |
|
|
|
if idx >= 0 and idx < rank: |
|
return idx |
|
|
|
if idx < 0: |
|
_idx = idx + rank |
|
else: |
|
_idx = idx |
|
|
|
if _idx < 0 or _idx >= rank: |
|
|
|
msg = "Dimension out of range (expected to be in range of [{0}, {1}], but got {2})".format( |
|
-rank, rank - 1, idx |
|
) |
|
raise IndexError(msg) |
|
|
|
return _idx |
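# Example behavior (illustrative):
#
#   canonicalize_dim(4, -1)  # -> 3
#   canonicalize_dim(4, 2)   # -> 2
#   canonicalize_dim(0, 0)   # -> 0 (scalars wrap when wrap_scalar=True)
#   canonicalize_dim(4, 4)   # raises IndexError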
|
|
|
|
|
|
|
|
|
@overload |
|
def canonicalize_dims(rank: int, indices: Sequence[int]) -> Tuple[int, ...]: |
|
pass |
|
|
|
|
|
@overload |
|
def canonicalize_dims(rank: int, indices: int) -> int: |
|
pass |
|
|
|
|
|
def canonicalize_dims(rank, indices): |
|
if isinstance(indices, int): |
|
return canonicalize_dim(rank, indices) |
|
|
|
return tuple(canonicalize_dim(rank, x) for x in indices) |
|
|
|
|
|
def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: |
|
""" |
|
    Returns True if perm is a valid permutation of length rank.
|
""" |
|
|
|
if not isinstance(perm, Sequence): |
|
return False |
|
|
|
if not (tuple(sorted(perm)) == tuple(range(0, rank))): |
|
return False |
|
|
|
return True |
|
|
|
|
|
def is_same_shape(a: Sequence, b: Sequence) -> bool: |
|
""" |
|
Compares two shapes a and b, returning True if they are the same |
|
(their ranks and corresponding lengths match) and False otherwise. |
|
""" |
|
|
|
return tuple(a) == tuple(b) |
|
|
|
|
|
def is_cpu_scalar_tensor(a: Any) -> bool: |
|
return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" |
|
|
|
|
|
def check_same_device(*args, allow_cpu_scalar_tensors): |
|
""" |
|
Checks that all Tensors in args have the same device. |
|
|
|
Raises a RuntimeError when: |
|
- args contains an object whose type is not Tensor or Number |
|
- two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True |
|
""" |
|
|
|
if len(args) <= 1: |
|
return |
|
|
|
|
|
device = None |
|
for arg in args: |
|
if isinstance(arg, Number): |
|
continue |
|
elif isinstance(arg, TensorLike): |
|
if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): |
|
continue |
|
|
|
if device is None: |
|
device = arg.device |
|
|
|
if device != arg.device: |
|
msg = ( |
|
"Tensor on device " |
|
+ str(arg.device) |
|
+ " is not on the expected device " |
|
+ str(device) |
|
+ "!" |
|
) |
|
raise RuntimeError(msg) |
|
else: |
|
msg = ( |
|
"Unexpected type when checking for same device, " + str(type(arg)) + "!" |
|
) |
|
raise RuntimeError(msg) |
|
|
|
|
|
def canonicalize_device(device: DeviceLikeType) -> torch.device: |
|
if isinstance(device, torch.device): |
|
return device |
|
|
|
assert isinstance(device, str) |
|
return torch.device(device) |
|
|
|
|
|
|
|
|
|
|
|
def check_same_shape(*args, allow_cpu_scalar_tensors: bool): |
|
""" |
|
Checks that all Tensors in args have the same shape. |
|
|
|
Raises a RuntimeError when: |
|
- args contains an object whose type is not Tensor or Number |
|
      - two Tensor objects in args have different shapes
|
""" |
|
shape = None |
|
|
|
for arg in args: |
|
if isinstance(arg, Number): |
|
continue |
|
elif isinstance(arg, TensorLike): |
|
if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): |
|
continue |
|
|
|
if shape is None: |
|
shape = arg.shape |
|
|
|
if not is_same_shape(shape, arg.shape): |
|
msg = "Shape {0} is not the expected shape {1}!".format( |
|
arg.shape, shape |
|
) |
|
raise RuntimeError(msg) |
|
else: |
|
msg = ( |
|
"Unexpected type when checking for same shape, " + str(type(arg)) + "!" |
|
) |
|
raise RuntimeError(msg) |
|
|
|
|
|
|
|
|
|
def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: |
|
shape = None |
|
scalar_shape = None |
|
|
|
for arg in args: |
|
if isinstance(arg, Number): |
|
continue |
|
elif isinstance(arg, TensorLike): |
|
if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): |
|
scalar_shape = arg.shape |
|
continue |
|
|
|
if shape is None: |
|
shape = arg.shape |
|
|
|
if not is_same_shape(shape, arg.shape): |
|
return None |
|
else: |
|
return None |
|
|
|
return shape if shape is not None else scalar_shape |
|
|
|
|
|
|
|
|
|
def extract_dims_from_varargs(
    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
) -> DimsSequenceType:
|
if dims and isinstance(dims[0], Sequence): |
|
assert len(dims) == 1 |
|
dims = cast(Tuple[DimsSequenceType], dims) |
|
return dims[0] |
|
else: |
|
return cast(DimsSequenceType, dims) |
|
|
|
|
|
def extract_shape_from_varargs( |
|
shape: Union[ShapeType, Tuple[ShapeType]], |
|
validate=True, |
|
) -> Tuple[int, ...]: |
|
""" |
|
Returns a shape from varargs. |
|
|
|
In PyTorch, operations that accept shapes often accept them as varargs, like |
|
    foo(*shape). However a user can pass the shape as individual integers,
|
like this: |
|
|
|
foo(1, 2, 3) |
|
|
|
    or as a single sequence of integers
|
|
|
foo((1, 2, 3)) |
|
|
|
In the first case shape will be a tuple of integers, and in the second case it's a tuple |
|
containing a tuple of integers. This validates those inputs and canonicalizes them |
|
to a tuple of integers. |
|
""" |
|
|
|
|
|
if len(shape) == 1 and isinstance(shape[0], Sequence): |
|
shape = shape[0] |
|
|
|
if validate: |
|
validate_shape(shape) |
|
return shape |
|
|
|
|
|
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: |
|
""" |
|
Infers the size of a dim with size -1, if it exists. |
|
Also checks that new shape is compatible with the number of elements. |
|
""" |
|
dim = None |
|
newsize = 1 |
|
for i, d in enumerate(shape): |
|
if d == -1: |
|
check(dim is None, lambda: "only one dimension can be inferred") |
|
dim = i |
|
elif d >= 0: |
|
newsize *= d |
|
else: |
|
check(False, lambda: f"invalid shape dimension {d}") |
|
check( |
|
numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0), |
|
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", |
|
) |
|
if dim is not None: |
|
check( |
|
newsize != 0, |
|
lambda: f"cannot reshape tensor fo 0 elements into shape {shape} because the " |
|
f"unspecified dimension size -1 can be any value and is ambiguous", |
|
) |
|
shape = list(shape) |
|
shape[dim] = numel // newsize |
|
return tuple(shape) |
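# For example (illustrative): infer_size((2, -1), 6) returns (2, 3), while
# infer_size((2, -1), 7) raises because 7 is not divisible by 2.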
|
|
|
|
|
_integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) |
|
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) |
|
_float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64) |
|
_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) |
|
|
|
|
|
def is_boolean_dtype(dtype: torch.dtype) -> bool: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtype is torch.bool |
|
|
|
|
|
def is_integer_dtype(dtype: torch.dtype) -> bool: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtype in _integer_dtypes |
|
|
|
|
|
def is_low_precision_dtype(dtype: torch.dtype) -> bool: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtype in _low_precision_dtypes |
|
|
|
|
|
def is_float_dtype(dtype: torch.dtype) -> bool: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtype in _float_dtypes |
|
|
|
|
|
def is_complex_dtype(dtype: torch.dtype) -> bool: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtype in _complex_dtypes |
|
|
|
|
|
def is_grad_dtype(dtype: torch.dtype) -> bool: |
|
""" |
|
Checks if the dtype can require a gradient. |
|
""" |
|
return is_float_dtype(dtype) or is_complex_dtype(dtype) |
|
|
|
|
|
_complex_to_real_dtype_map = { |
|
torch.complex128: torch.float64, |
|
torch.complex64: torch.float32, |
|
torch.complex32: torch.float16, |
|
} |
|
|
|
_real_to_complex_dtype_map = { |
|
torch.float16: torch.complex32, |
|
torch.bfloat16: torch.complex64, |
|
torch.float32: torch.complex64, |
|
torch.float64: torch.complex128, |
|
} |
|
|
|
|
|
def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: |
|
return _complex_to_real_dtype_map[dtype] |
|
|
|
|
|
def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: |
|
return _real_to_complex_dtype_map[dtype] |
|
|
|
|
|
def dtype_to_type(dtype: torch.dtype) -> type: |
|
""" |
|
Computes the corresponding Python type (AKA "type kind") for the |
|
given dtype. |
|
""" |
|
assert isinstance(dtype, torch.dtype) |
|
|
|
if dtype is torch.bool: |
|
return bool |
|
if dtype in _integer_dtypes: |
|
return int |
|
if dtype in _float_dtypes: |
|
return float |
|
if dtype in _complex_dtypes: |
|
return complex |
|
|
|
raise ValueError("Invalid dtype!") |
|
|
|
|
|
def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: |
|
""" |
|
Computes the corresponding Python type constructor for the |
|
given dtype. |
|
""" |
|
from torch.fx.experimental.symbolic_shapes import sym_float |
|
|
|
assert isinstance(dtype, torch.dtype) |
|
|
|
    if dtype is torch.bool:
        return bool
    if dtype in _integer_dtypes:
        return int
    if dtype in _float_dtypes:
        return sym_float
    if dtype in _complex_dtypes:
        return complex
|
|
|
raise ValueError("Invalid dtype!") |
|
|
|
|
|
def type_to_dtype(typ: type) -> torch.dtype: |
|
""" |
|
Computes the corresponding dtype for a Number type. |
|
""" |
|
|
|
assert isinstance(typ, type) |
|
|
|
if typ is bool: |
|
return torch.bool |
|
if typ is int: |
|
return torch.long |
|
if typ is float: |
|
return torch.get_default_dtype() |
|
if typ is complex: |
|
return corresponding_complex_dtype(torch.get_default_dtype()) |
|
|
|
raise ValueError("Invalid type!") |
|
|
|
|
|
def get_dtype(x: Union[torch.Tensor, NumberType]): |
|
if isinstance(x, torch.Tensor): |
|
return x.dtype |
|
else: |
|
return type_to_dtype(type(x)) |
|
|
|
|
|
_ordered_types = (bool, int, float, complex) |
|
|
|
|
|
def check_fp_or_complex( |
|
dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True |
|
): |
|
""" |
|
    Checks that the dtype is floating point or complex.

    If allow_low_precision_dtypes is True, float16, bfloat16, and complex32 are permitted.
|
""" |
|
check( |
|
is_float_dtype(dtype) or is_complex_dtype(dtype), |
|
lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", |
|
) |
|
check( |
|
allow_low_precision_dtypes or not is_low_precision_dtype(dtype), |
|
lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", |
|
) |
|
|
|
|
|
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): |
|
check( |
|
len(A.shape) >= 2, |
|
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", |
|
) |
|
|
|
|
|
def get_higher_type(a: type, b: type) -> type: |
|
""" |
|
Returns the higher of the two given Number types. |
|
|
|
The types are ordered bool -> int -> float -> complex. |
|
""" |
|
|
|
assert a in _ordered_types |
|
assert b in _ordered_types |
|
|
|
if a is b: |
|
return a |
|
|
|
for typ in _ordered_types: |
|
if a is typ: |
|
return b |
|
if b is typ: |
|
return a |
|
|
|
raise ValueError("Unknown Python scalar type!") |
|
|
|
|
|
|
|
|
|
|
|
def get_higher_dtype( |
|
a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], |
|
b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], |
|
) -> Optional[torch.dtype]: |
|
""" |
|
Computes the "lowest" datatype that is weakly |
|
"higher" than both a and b. |
|
""" |
|
|
|
|
|
assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) |
|
assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) |
|
|
|
def _extract_dtype( |
|
x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] |
|
) -> Optional[torch.dtype]: |
|
if x is None: |
|
return None |
|
if isinstance(x, torch.dtype): |
|
return x |
|
if isinstance(x, TensorLike): |
|
return x.dtype |
|
if isinstance(x, Number): |
|
return type_to_dtype(type(x)) |
|
|
|
raise RuntimeError("Unexpected type given to _extract_dtype!") |
|
|
|
a, b = _extract_dtype(a), _extract_dtype(b) |
|
|
|
if a is b: |
|
return a |
|
|
|
if a is None: |
|
return b |
|
|
|
if b is None: |
|
return a |
|
|
|
ordered_datatypes = ( |
|
(torch.bool,), |
|
(torch.uint8, torch.int8), |
|
(torch.int16,), |
|
(torch.int32,), |
|
(torch.int64,), |
|
(torch.float16, torch.bfloat16), |
|
(torch.float32,), |
|
(torch.float64,), |
|
(torch.complex32,), |
|
(torch.complex64,), |
|
(torch.complex128,), |
|
) |
|
|
|
for idx, dtypes in enumerate(ordered_datatypes): |
|
if a in dtypes and b in dtypes: |
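            # both dtypes are in the same "bucket" but distinct (e.g. float16
            # and bfloat16), so promote to the first dtype of the next bucket
            # (float32 in that example)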
|
return ordered_datatypes[idx + 1][0] |
|
if a in dtypes: |
|
return b |
|
if b in dtypes: |
|
return a |
|
|
|
raise RuntimeError("Unexpected termination!") |
|
|
|
|
|
def check_pin_memory(pin_memory: bool): |
|
check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError) |
|
|
|
|
|
def check_layout(layout: torch.layout): |
|
check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError) |
|
|
|
|
|
|
|
def is_weakly_lesser_type(a: type, b: type) -> bool: |
|
""" |
|
Compares two types, a and b, returning True if a is weakly "less" than b. |
|
|
|
The comparison is determined by the following type ordering: bool, int, float, complex. |
|
""" |
|
ordered_types = ( |
|
bool, |
|
int, |
|
float, |
|
complex, |
|
) |
|
|
|
assert a in ordered_types |
|
assert b in ordered_types |
|
|
|
for typ in ordered_types: |
|
if a == typ: |
|
return True |
|
if b == typ: |
|
return False |
|
|
|
raise RuntimeError("Unexpected termination!") |
|
|
|
|
|
def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: |
|
for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): |
|
if fn(cast_to): |
|
return True |
|
if fn(cast_from): |
|
return False |
|
|
|
raise ValueError("Received unknown dtypes {0}, {1}!".format(cast_to, cast_from)) |
|
|
|
|
|
def check_same_dtype(*args): |
|
""" |
|
    Checks that all Tensors in args have the same dtype and that all Numbers have the
|
same corresponding Python type. |
|
|
|
Raises a RuntimeError when: |
|
- args contains an object whose type is not Tensor or Number |
|
      - two Tensor objects in args have different dtypes
|
- two Number objects in args have different types |
|
      - there are Tensors and Numbers in args, and one of those Tensors' corresponding
|
Python types is different from the type of one of those Numbers |
|
""" |
|
full_dtype = None |
|
scalar_type = None |
|
|
|
for arg in args: |
|
if isinstance(arg, Number): |
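            # note: scalar (Number) type checking is skipped here; only tensor
            # dtypes and their corresponding Python types are validated below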
|
|
|
            continue
|
elif isinstance(arg, TensorLike): |
|
if full_dtype is None: |
|
full_dtype = arg.dtype |
|
if scalar_type is None: |
|
scalar_type = dtype_to_type(arg.dtype) |
|
|
|
if full_dtype is not arg.dtype: |
|
msg = ( |
|
"Tensor with dtype " |
|
+ str(arg.dtype) |
|
+ " is not the expected dtype of " |
|
+ str(full_dtype) |
|
+ "!" |
|
) |
|
raise RuntimeError(msg) |
|
|
|
arg_type = dtype_to_type(arg.dtype) |
|
if arg_type is not scalar_type: |
|
msg = ( |
|
"Tensor with corresponding Python type " |
|
+ str(arg_type) |
|
+ " is not the expected type of " |
|
+ str(scalar_type) |
|
+ "!" |
|
) |
|
raise RuntimeError(msg) |
|
else: |
|
msg = ( |
|
"Unexpected type when checking for same dtype, " + str(type(arg)) + "!" |
|
) |
|
raise RuntimeError(msg) |
|
|
|
|
|
|
|
_computation_dtype_map = { |
|
torch.bfloat16: torch.float32, |
|
torch.float16: torch.float32, |
|
torch.complex32: torch.complex64, |
|
} |
|
|
|
|
|
def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: |
|
return _computation_dtype_map.get(dtype, dtype) |
|
|
|
|
|
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): |
|
DEFAULT = (0,) |
|
NO_OPMATH = (1,) |
|
INT_TO_FLOAT = (2,) |
|
ALWAYS_BOOL = (3,) |
|
COMPLEX_TO_FLOAT = (4,) |
|
BOOL_TO_LONG = (5,) |
|
|
|
|
|
class REDUCTION_OUTPUT_TYPE_KIND(Enum): |
|
SAME = (0,) |
|
COMPLEX_TO_FLOAT = (1,) |
|
KEEP_PROMOTED_TYPE = (2,) |
|
    ALWAYS_BOOL = (3,)


|
class RETURN_TYPE(Enum): |
|
NEW = (0,) |
|
VIEW = (1,) |
|
INPLACE = (2,) |
|
|
|
|
|
|
|
def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type: |
|
if isinstance(x, torch.SymIntNode): |
|
return int |
|
elif isinstance(x, torch.SymFloatNode): |
|
return float |
|
else: |
|
return type(x) |
|
|
|
|
|
|
|
def elementwise_dtypes( |
|
*_args, |
|
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, |
|
) -> Tuple[torch.dtype, torch.dtype]: |
|
""" |
|
Computes the computation and result dtypes for elementwise type promotion |
|
on the given arguments and with the given elementwise type promotion kind. |
|
|
|
Note that not all inputs to an elementwise operation necessarily participate in type promotion. |
|
For example, the "alpha" parameter of torch.add does not participate in type promotion, |
|
although it may be cast to the Python type corresponding to the computation dtype that |
|
the type promotion algorithm determines. |
|
|
|
Default elementwise type promotion, which all other type promotion kinds tweak (see below), |
|
first decides which of four ordered types to use: |
|
|
|
bool -> integer -> floating point -> complex |
|
|
|
The selected type is the "lowest" type in the above list such that all number arguments |
|
have a weakly "lower" type and all tensor arguments have a weakly lower corresponding |
|
type for their dtype. |
|
|
|
Once the type is determined, the particular result dtype is found. The dtypes are |
|
partially ordered as follows: |
|
|
|
bool -> uint8, int8 -> int16 -> int32 -> int64 -> |
|
float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 |
|
|
|
The result dtype is selected by: |
|
- if no tensor's dtype has the same corresponding type as the one selected, |
|
then the result dtype is the (default) dtype corresponding to the selected type |
|
(for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype) |
|
- if the result type is complex then the dtype is: |
|
- the default complex dtype if there are no floating point or complex tensors |
|
- if there are floating point or complex tensors with one or more dimensions, then |
|
the complex dtype corresponding to the highest corresponding complex dtype among those tensors |
|
(for example, double + cfloat -> cdouble) |
|
- if there are only floating point or complex tensors with zero dimensions, then |
|
the complex dtype corresponding to the highest corresponding complex dtype among those tensors |
|
- if the first two cases do not apply, the result dtype is the highest dtype among |
|
all tensors with one or more dimensions of the output type, and if there are no such |
|
tensors then it's the highest dtype among all tensors with zero dimensions of the output type |
|
(for example, long + half -> half, even if the half tensor has zero dimensions) |
|
|
|
The "corresponding complex dtypes" are: |
|
float16 -> complex32 |
|
bfloat16 -> complex64 |
|
float32 -> complex64 |
|
float64 -> complex128 |
|
complex32 -> complex32 |
|
complex64 -> complex64 |
|
complex128 -> complex128 |
|
|
|
The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation |
|
dtype by mapping low precision floating point and complex dtypes as follows: |
|
|
|
float16 -> float32 |
|
bfloat16 -> float32 |
|
complex32 -> complex64 |
|
|
|
This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the |
|
computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels |
|
which perform no mathematical operations on their tensors (see below for examples). |
|
|
|
    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
|
and computation dtypes to the appropriate op math dtype. |
|
|
|
The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this |
|
mapping: |
|
|
|
complex32 -> float16 |
|
complex64 -> float32 |
|
complex128 -> float64 |
|
|
|
Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does. |
|
|
|
The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long. |
|
|
|
The ALWAYS_BOOL type promotion kind always sets the result dtype to bool. |
|
|
|
Example operators for each type promotion option: |
|
DEFAULT : add |
|
NO_OPMATH : where, nextafter, cat |
|
INT_TO_FLOAT : sin |
|
COMPLEX_TO_FLOAT : abs |
|
BOOL_TO_LONG : pow |
|
ALWAYS_BOOL : eq |
|
|
|
""" |
|
|
|
args = tuple(x for x in _args if x is not None) |
|
|
|
highest_type: type = bool |
|
for x in args: |
|
if not isinstance(x, (Number, TensorLike)): |
|
msg = ( |
|
"Unexpected type {0} when computing elementwise type promotion!".format( |
|
str(type(x)) |
|
) |
|
) |
|
raise ValueError(msg) |
|
|
|
if isinstance(x, Number): |
|
highest_type = get_higher_type(highest_type, number_type(x)) |
|
else: |
|
|
|
highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype)) |
|
|
|
result_dtype = None |
|
|
|
def _find_highest_dtype_filtered( |
|
args, filter, *, float_as_complex=False |
|
) -> Optional[torch.dtype]: |
|
zero_dim_tensor_dtype = None |
|
one_plus_dim_tensor_dtype = None |
|
for x in args: |
|
if isinstance(x, TensorLike) and filter(x.dtype): |
|
_dtype = x.dtype |
|
if float_as_complex and is_float_dtype(_dtype): |
|
_dtype = corresponding_complex_dtype(_dtype) |
|
if x.ndim == 0: |
|
zero_dim_tensor_dtype = get_higher_dtype( |
|
zero_dim_tensor_dtype, _dtype |
|
) |
|
else: |
|
|
|
one_plus_dim_tensor_dtype = get_higher_dtype( |
|
one_plus_dim_tensor_dtype, _dtype |
|
) |
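        # Prefers the dtype of tensors with one or more dimensions over the
        # dtype of zero-dim tensors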
|
|
|
|
|
if one_plus_dim_tensor_dtype is not None: |
|
return one_plus_dim_tensor_dtype |
|
|
|
return zero_dim_tensor_dtype |
|
|
|
if highest_type is float: |
|
result_dtype = _find_highest_dtype_filtered(args, is_float_dtype) |
|
result_dtype = ( |
|
torch.get_default_dtype() if result_dtype is None else result_dtype |
|
) |
|
elif highest_type is complex: |
|
result_dtype = _find_highest_dtype_filtered( |
|
args, |
|
lambda x: is_float_dtype(x) or is_complex_dtype(x), |
|
float_as_complex=True, |
|
) |
|
if result_dtype is None: |
|
result_dtype = corresponding_complex_dtype(torch.get_default_dtype()) |
|
elif highest_type is int: |
|
result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype) |
|
result_dtype = torch.long if result_dtype is None else result_dtype |
|
else: |
|
|
|
result_dtype = torch.bool |
|
|
|
if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: |
|
return get_computation_dtype(result_dtype), result_dtype |
|
elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH: |
|
return result_dtype, result_dtype |
|
elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT: |
|
if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype): |
|
result_dtype = torch.get_default_dtype() |
|
return get_computation_dtype(result_dtype), result_dtype |
|
elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: |
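        # note: computation still occurs in the complex dtype; only the
        # result dtype is mapped to the corresponding real dtype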
|
|
|
computation_dtype = get_computation_dtype(result_dtype) |
|
if is_complex_dtype(result_dtype): |
|
result_dtype = corresponding_real_dtype(result_dtype) |
|
return computation_dtype, result_dtype |
|
elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG: |
|
if is_boolean_dtype(result_dtype): |
|
return torch.long, torch.long |
|
return get_computation_dtype(result_dtype), result_dtype |
|
elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: |
|
return get_computation_dtype(result_dtype), torch.bool |
|
else: |
|
raise ValueError( |
|
"Unknown type promotion kind {0}".format(str(type_promotion_kind)) |
|
) |
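# A small usage sketch (values illustrative): for a float16 tensor and a Python
# float with the DEFAULT kind,
#
#   elementwise_dtypes(
#       torch.empty(3, dtype=torch.float16),
#       2.0,
#       type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   )
#
# returns (torch.float32, torch.float16): float16 is the result dtype, and
# "op math" upgrades the computation dtype to float32.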
|
|
|
|
|
def reduction_dtypes( |
|
arg, |
|
output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, |
|
dtype: Optional[torch.dtype] = None, |
|
) -> Tuple[torch.dtype, Optional[torch.dtype]]: |
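    # Even reductions that don't strictly require type promotion still compute
    # in a "computation dtype" (e.g. float32 for float16 inputs); the result
    # dtype is then derived from the output type kind below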
|
|
|
|
|
|
|
inp_dtype = dtype if dtype is not None else arg.dtype |
|
computation_dtype = get_computation_dtype(inp_dtype) |
|
if ( |
|
output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME |
|
or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT |
|
): |
|
result_dtype = dtype if dtype else arg.dtype |
|
if ( |
|
output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT |
|
and is_complex_dtype(result_dtype) |
|
): |
|
result_dtype = corresponding_real_dtype(result_dtype) |
|
elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: |
|
result_dtype = None |
|
else: |
|
result_dtype = torch.bool |
|
return computation_dtype, result_dtype |
|
|
|
|
|
def make_contiguous_strides_for( |
|
shape: ShapeType, row_major: bool = True |
|
) -> Tuple[int, ...]: |
|
""" |
|
    Returns the strides of a contiguous tensor if row_major=True.
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices.
    This is often used when calling external libraries like BLAS/LAPACK/cuSolver...
|
""" |
|
validate_shape(shape) |
|
if not shape: |
|
return () |
|
|
|
multiplier = 1 |
|
strides = [] |
|
for l in reversed(shape): |
|
strides.append(multiplier) |
|
if l != 0: |
|
multiplier *= l |
|
|
|
result = tuple(reversed(strides)) |
|
|
|
if row_major: |
|
return result |
|
else: |
|
if len(shape) < 2: |
|
return result |
|
return result[:-2] + (1, max(shape[-2], 1)) |
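# For example (illustrative): make_contiguous_strides_for((2, 3, 4)) returns
# (12, 4, 1), while make_contiguous_strides_for((5, 2, 3), row_major=False)
# returns (6, 1, 2) -- a batch of five column-major 2x3 matrices.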
|
|
|
|
|
def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: |
|
|
|
check( |
|
len(shape) == 4, |
|
lambda: "Only tensors of rank 4 can use the channels_last memory format", |
|
) |
|
|
|
multiplier = 1 |
|
strides = [0] * 4 |
|
for idx in (1, -1, -2, 0): |
|
|
|
|
|
strides[idx] = multiplier |
|
multiplier *= shape[idx] |
|
|
|
return tuple(strides) |
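# For example (illustrative): an NCHW shape of (2, 3, 4, 5) yields strides
# (60, 1, 15, 3) -- the channel dimension becomes innermost (stride 1), as
# in an NHWC layout.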
|
|
|
|
|
def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: |
|
check( |
|
len(shape) == 5, |
|
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", |
|
) |
|
|
|
multiplier = 1 |
|
strides = [0] * 5 |
|
for idx in (1, -1, -2, -3, 0): |
|
|
|
|
|
strides[idx] = multiplier |
|
multiplier *= shape[idx] |
|
|
|
return tuple(strides) |
|
|
|
|
|
def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: |
|
ndim = len(shape) if isinstance(shape, Sequence) else 1 |
|
if ndim == 4: |
|
return make_channels_last_2d_strides_for(shape) |
|
elif ndim == 5: |
|
return make_channels_last_3d_strides_for(shape) |
|
else: |
|
raise RuntimeError( |
|
f"no channels last format strides exist in {ndim} dimensions" |
|
) |
|
|
|
|
|
def compute_reduction_output_shape( |
|
shape: ShapeType, dimensions: Sequence |
|
) -> Tuple[int, ...]: |
|
for idx in dimensions: |
|
validate_idx(len(shape), idx) |
|
|
|
new_shape = [] |
|
for idx in range(len(shape)): |
|
if idx in dimensions: |
|
continue |
|
|
|
new_shape.append(shape[idx]) |
|
|
|
return tuple(new_shape) |
|
|
|
|
|
def validate_no_repeating_dims(dims: Sequence): |
|
if len(dims) != len(set(dims)): |
|
raise RuntimeError("duplicate value in the list of dims") |
|
|
|
|
|
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]: |
|
if dims is None: |
|
return tuple(range(len(shape))) |
|
dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) |
|
validate_no_repeating_dims(dims) |
|
return dims |
|
|
|
|
|
def set_correction( |
|
unbiased: Optional[bool] = None, |
|
correction: Optional[int] = None, |
|
): |
|
if correction is not None and unbiased is not None: |
|
raise RuntimeError("cannot specify both correction and unbiased arguments") |
|
elif correction is None and unbiased is None: |
|
correction = 1 |
|
elif correction is None and unbiased is not None: |
|
correction = 0 if unbiased is False else 1 |
|
if not isinstance(correction, int): |
|
raise ValueError("correction argument should be integer") |
|
if correction < 0: |
|
raise ValueError("correction argument should be non-negative") |
|
return correction |
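# Example behavior (illustrative): set_correction() -> 1 (Bessel's correction),
# set_correction(unbiased=False) -> 0, set_correction(correction=2) -> 2, and
# passing both arguments raises a RuntimeError.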
|
|
|
|
|
def check_in_bounds_for_storage( |
|
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int |
|
): |
|
""" |
|
Determines if the given shape, strides, and offset are valid for the given storage. |
|
""" |
|
|
|
|
|
    # Short-circuits if the shape has no elements (the explicit initializer
    # makes reduce safe for rank-0 shapes)
    if reduce(operator.mul, shape, 1) == 0:
|
return |
|
|
|
length = a.size() - storage_offset |
|
max_offset = 0 |
|
for x, y in zip(shape, strides): |
|
max_offset = max_offset + (x - 1) * y |
|
|
|
if max_offset >= length: |
|
required_length = max_offset + storage_offset |
|
msg = ( |
|
"Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, " |
|
"which requires a storage of size {4}".format( |
|
a.size(), storage_offset, str(shape), str(strides), required_length |
|
) |
|
) |
|
raise ValueError(msg) |
|
|
|
|
|
def check( |
|
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError |
|
) -> None: |
|
""" |
|
    Helper function for raising exc_type (default: RuntimeError) if a boolean condition fails.
|
Error message is a callable producing a string (to avoid wasting time |
|
string formatting in non-error case, and also to make it easier for torchdynamo |
|
to trace.) |
|
""" |
|
if not b: |
|
raise exc_type(s()) |
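# Typical usage (illustrative):
#
#   check(x.ndim == 2, lambda: f"expected a matrix, got {x.ndim} dimensions")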
|
|
|
|
|
|
|
|
|
def are_strides_like_channels_last( |
|
shape: Sequence[int], strides: Sequence[int] |
|
) -> bool: |
|
ndim = len(shape) |
|
|
|
    if ndim == 4:
        # NHWC: check dims in the order C, W, H, N (innermost to outermost)
        dim_order = [1, 3, 2, 0]
    elif ndim == 5:
        # NDHWC: check dims in the order C, W, H, D, N (innermost to outermost)
        dim_order = [1, 4, 3, 2, 0]
    else:
        return False
|
|
|
if strides[1] == 0: |
|
return False |
|
|
|
    # min_stride tracks the smallest stride that the next (more outer)
    # dimension is allowed to have
    min_stride = 0
    for d in dim_order:
        if shape[d] == 0:
            return False
        if strides[d] < min_stride:
            return False
        if d == 0 and min_stride == strides[1]:
            return False
        min_stride = strides[d]
        if strides[d] > 1:
            min_stride *= shape[d]
|
return True |
|
|
|
|
|
def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: |
|
if x.layout != torch.strided: |
|
return torch.contiguous_format |
|
|
|
if are_strides_like_channels_last(x.shape, x.stride()): |
|
return torch.channels_last if x.ndim == 4 else torch.channels_last_3d |
|
|
|
return torch.contiguous_format |
|
|
|
|
|
def prod(xs: Sequence[NumberType]) -> NumberType: |
|
"""Product of elements in input sequence. Returns 1 for empty sequence""" |
|
return reduce(operator.mul, xs, 1) |
|
|
|
|
|
def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: |
|
"""Checks if a shape can be expanded to another shape. |
|
This is equivalent to checking if the two shapes are broadcastable. |
|
""" |
|
|
|
|
|
if len(shape) > len(desired): |
|
return False |
|
for i in range(len(shape)): |
|
if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1: |
|
return False |
|
return True |
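# For example (illustrative): (3, 1) is expandable to (2, 3, 4) because trailing
# dimensions either match or are 1, but (3, 2) is not.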
|
|
|
|
|
def mask_tensor(mask: TensorLikeType, t: TensorLikeType): |
|
""" |
|
Similar to torch.where(mask, t, 0) but if t is boolean, |
|
result is also boolean and not promoted to int. |
|
""" |
|
|
|
|
|
if t.dtype is torch.bool: |
|
return mask.logical_and(t) |
|
else: |
|
return torch.where(mask, t, 0) |
|
|
|
|
|
def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: |
|
return dtype if dtype is not None else torch.get_default_dtype() |
|
|
|
|
|
def device_or_default(device: Optional[torch.device]) -> torch.device: |
|
return device if device is not None else torch.device("cpu") |
|
|
|
|
|
def layout_or_default(layout: Optional[torch.layout]) -> torch.layout: |
|
return layout if layout is not None else torch.strided |
|
|