|
"""This module exist to be able to deprecate functions publicly without doing so internally. The deprecated |
|
public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal |
|
versions should be imported from torch.testing._internal |
|
""" |
|
|
|
from typing import List |
|
|
|
import torch |
|
|
|
# Names of the dtype-getter helpers defined in this module. It is kept as its own list (separate from ``__all__``)
# so it can be referenced on its own -- presumably by the deprecated public wrappers in torch.testing._deprecated
# mentioned in the module docstring; TODO confirm against that module.
# NOTE(review): ``complex_types_and`` and ``get_all_qint_dtypes`` are defined in this module but are not listed
# here (nor in ``__all__``), so a star-import does not export them. Confirm whether that omission is intentional.
__all_dtype_getters__ = [
    "_validate_dtypes",
    "_dispatch_dtypes",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "all_types_and_half",
    "complex_types",
    "empty_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "double_types",
    "floating_types_and_half",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
]

# Names exported by ``from <this module> import *``: every dtype getter above plus the device-type helper.
__all__ = [
    *__all_dtype_getters__,
    "get_all_device_types",
]
|
|
|
|
|
|
|
|
|
|
|
def _validate_dtypes(*dtypes): |
|
for dtype in dtypes: |
|
assert isinstance(dtype, torch.dtype) |
|
return dtypes |
|
|
|
|
|
class _dispatch_dtypes(tuple): |
|
def __add__(self, other): |
|
assert isinstance(other, tuple) |
|
return _dispatch_dtypes(tuple.__add__(self, other)) |
|
|
|
# Shared empty dtype group.
_empty_types = _dispatch_dtypes(tuple())


def empty_types():
    """Return the shared, empty ``_dispatch_dtypes`` group."""
    return _empty_types
|
|
|
# Floating-point dtypes: float32 then float64.
_floating_types = _dispatch_dtypes([torch.float32, torch.float64])


def floating_types():
    """Return the floating-point dtype group ``(torch.float32, torch.float64)``."""
    return _floating_types


# The floating group with half precision appended at the end.
_floating_types_and_half = _floating_types + (torch.half,)


def floating_types_and_half():
    """Return ``floating_types()`` followed by ``torch.half``."""
    return _floating_types_and_half


def floating_types_and(*dtypes):
    """Return ``floating_types()`` extended with ``dtypes``; each entry is checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _floating_types + extra
|
|
|
# Floating-point dtypes followed by the two standard complex dtypes (cfloat, cdouble).
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)


def floating_and_complex_types():
    """Return float32, float64, cfloat and cdouble as one dtype group."""
    return _floating_and_complex_types


def floating_and_complex_types_and(*dtypes):
    """Return ``floating_and_complex_types()`` extended with ``dtypes``, each checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _floating_and_complex_types + extra
|
|
|
# Double-precision real and complex dtypes.
_double_types = _dispatch_dtypes([torch.float64, torch.complex128])


def double_types():
    """Return the double-precision dtype group ``(torch.float64, torch.complex128)``."""
    return _double_types
|
|
|
# Integer dtypes: unsigned 8-bit first, then signed 8/16/32/64-bit.
_integral_types = _dispatch_dtypes([
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
])


def integral_types():
    """Return the integer dtype group (uint8, int8, int16, int32, int64)."""
    return _integral_types


def integral_types_and(*dtypes):
    """Return ``integral_types()`` extended with ``dtypes``, each checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _integral_types + extra
|
|
|
# Every real (non-complex, non-bool) dtype group: the floating dtypes first, then the integers.
_all_types = _floating_types + _integral_types


def all_types():
    """Return the floating dtypes followed by the integer dtypes."""
    return _all_types


def all_types_and(*dtypes):
    """Return ``all_types()`` extended with ``dtypes``, each checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _all_types + extra
|
|
|
# The two standard complex dtypes.
_complex_types = _dispatch_dtypes([torch.cfloat, torch.cdouble])


def complex_types():
    """Return the complex dtype group ``(torch.cfloat, torch.cdouble)``."""
    return _complex_types


def complex_types_and(*dtypes):
    """Return ``complex_types()`` extended with ``dtypes``, each checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _complex_types + extra
|
|
|
# Floating, then integer, then complex dtypes.
_all_types_and_complex = _all_types + _complex_types


def all_types_and_complex():
    """Return ``all_types()`` followed by ``complex_types()``."""
    return _all_types_and_complex


def all_types_and_complex_and(*dtypes):
    """Return ``all_types_and_complex()`` extended with ``dtypes``, each checked by ``_validate_dtypes``."""
    extra = _validate_dtypes(*dtypes)
    return _all_types_and_complex + extra
|
|
|
# Every real dtype with half precision appended at the end.
_all_types_and_half = _all_types + (torch.half,)


def all_types_and_half():
    """Return ``all_types()`` followed by ``torch.half``."""
    return _all_types_and_half
|
|
|
|
|
|
|
|
|
def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False,
                   include_qint=False,
                   ) -> List[torch.dtype]:
    """Build a new list of dtypes selected by the ``include_*`` flags.

    The order is: integer dtypes, floating-point dtypes (``float16``/``bfloat16`` per ``include_half`` /
    ``include_bfloat16``), ``torch.bool`` if ``include_bool``, complex dtypes if ``include_complex``
    (``complex32`` only when ``include_complex32``), and quantized dtypes if ``include_qint``.
    """
    result = get_all_int_dtypes()
    result.extend(get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16))
    if include_bool:
        result.append(torch.bool)
    if include_complex:
        result.extend(get_all_complex_dtypes(include_complex32))
    if include_qint:
        result.extend(get_all_qint_dtypes())
    return result
|
|
|
def get_all_math_dtypes(device) -> List[torch.dtype]:
    """Return a new list of the integer, floating-point and complex dtypes to use on ``device``.

    ``torch.float16`` is included only when the device is a CUDA device. ``torch.bfloat16`` is never included.
    Complex dtypes come from ``get_all_complex_dtypes()`` with its default (``complex32`` excluded).

    Args:
        device: a device string such as ``'cpu'`` or ``'cuda:0'``, or a ``torch.device``.

    Returns:
        Integer dtypes, then floating-point dtypes, then complex dtypes.
    """
    # str() is a no-op for strings and turns a torch.device into e.g. 'cuda:0', so both forms are accepted
    # (a torch.device has no ``startswith`` method).
    is_cuda = str(device).startswith('cuda')
    return (get_all_int_dtypes()
            + get_all_fp_dtypes(include_half=is_cuda, include_bfloat16=False)
            + get_all_complex_dtypes())
|
|
|
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    """Return a new list of complex dtypes: complex64 and complex128, preceded by complex32 if requested."""
    complex_dtypes = [torch.complex64, torch.complex128]
    if include_complex32:
        # complex32 goes first so the list stays ordered by precision.
        complex_dtypes.insert(0, torch.complex32)
    return complex_dtypes
|
|
|
|
|
def get_all_int_dtypes() -> List[torch.dtype]:
    """Return a new list of the integer dtypes, unsigned 8-bit first, then signed by increasing width."""
    return [
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ]
|
|
|
|
|
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    """Return a new list of floating-point dtypes.

    float32 and float64 are always present. float16 follows if ``include_half`` is truthy, then bfloat16 if
    ``include_bfloat16`` is truthy.
    """
    optional_dtypes = (
        (include_half, torch.float16),
        (include_bfloat16, torch.bfloat16),
    )
    return [torch.float32, torch.float64] + [dtype for wanted, dtype in optional_dtypes if wanted]
|
|
|
|
|
def get_all_qint_dtypes() -> List[torch.dtype]:
    """Return a new list of the quantized dtypes: qint8, quint8, qint32, quint4x2 and quint2x4."""
    return [
        torch.qint8,
        torch.quint8,
        torch.qint32,
        torch.quint4x2,
        torch.quint2x4,
    ]
|
|
|
|
|
def get_all_device_types() -> List[str]:
    """Return ``['cpu']``, with ``'cuda'`` appended when ``torch.cuda.is_available()`` is true."""
    device_types = ['cpu']
    if torch.cuda.is_available():
        device_types.append('cuda')
    return device_types
|
|