File size: 5,133 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""This module exist to be able to deprecate functions publicly without doing so internally. The deprecated
public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal
versions should be imported from torch.testing._internal
"""

from typing import List

import torch

# Names of the dtype helpers (functions and classes) defined in this module.
# This list is re-used verbatim when building ``__all__`` below.
# NOTE(review): presumably also consumed by torch.testing._deprecated to build the deprecated public
# wrappers -- verify against that module before reordering or extending it.
# NOTE(review): ``complex_types_and`` and ``get_all_qint_dtypes`` are defined in this file but are not
# listed here -- confirm whether that omission is intentional.
__all_dtype_getters__ = [
    "_validate_dtypes",
    "_dispatch_dtypes",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "all_types_and_half",
    "complex_types",
    "empty_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "double_types",
    "floating_types_and_half",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
]

# Names exported by ``from <module> import *``: every entry of ``__all_dtype_getters__``
# followed by the device helper ``get_all_device_types``.
__all__ = [
    *__all_dtype_getters__,
    "get_all_device_types",
]

# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    """Check that every positional argument is a ``torch.dtype`` and return them.

    Args:
        *dtypes: Candidate dtype objects.

    Returns:
        The received arguments, unchanged, as a tuple (empty when called
        without arguments).

    Raises:
        AssertionError: If any argument is not a ``torch.dtype`` instance. The
            message names the offending value and its position.
    """
    # NOTE: deliberately kept as ``assert`` so the failure type stays
    # AssertionError for existing callers; like every assert, the check is
    # skipped when Python runs with ``-O``.
    for position, dtype in enumerate(dtypes):
        assert isinstance(dtype, torch.dtype), (
            f"expected a torch.dtype at position {position}, "
            f"but got {dtype!r} of type {type(dtype).__name__}"
        )
    return dtypes

# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    """Tuple subclass whose ``+`` keeps the result a ``_dispatch_dtypes``.

    Plain tuple concatenation would return a bare ``tuple``; this override
    re-wraps the concatenated result so chained additions keep the type.
    """

    def __add__(self, other):
        # Only tuples (including other _dispatch_dtypes) may be appended.
        assert isinstance(other, tuple)
        concatenated = super().__add__(other)
        return _dispatch_dtypes(concatenated)

# Empty dispatch tuple, built once at import time and shared by every caller.
_empty_types = _dispatch_dtypes(())
def empty_types():
    """Return the shared, empty ``_dispatch_dtypes`` tuple."""
    return _empty_types

# (float32, float64), built once at import time; the other floating tuples below extend it.
_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    """Return the shared dispatch tuple ``(torch.float32, torch.float64)``."""
    return _floating_types

# The floating types followed by torch.half, built once at import time.
_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    """Return the shared dispatch tuple of the floating dtypes followed by ``torch.half``."""
    return _floating_types_and_half

def floating_types_and(*dtypes):
    """Return the floating dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _floating_types + extras

# The floating types followed by (cfloat, cdouble), built once at import time.
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    """Return the shared dispatch tuple of the floating dtypes followed by ``torch.cfloat`` and ``torch.cdouble``."""
    return _floating_and_complex_types

def floating_and_complex_types_and(*dtypes):
    """Return the floating and complex dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _floating_and_complex_types + extras

# (float64, complex128), built once at import time.
_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
def double_types():
    """Return the shared dispatch tuple ``(torch.float64, torch.complex128)``."""
    return _double_types

# (uint8, int8, int16, int32, int64), built once at import time.
_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    """Return the shared dispatch tuple of integer dtypes: ``uint8``, ``int8``, ``int16``, ``int32``, ``int64``."""
    return _integral_types

def integral_types_and(*dtypes):
    """Return the integer dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _integral_types + extras

# The floating types followed by the integral types, built once at import time.
_all_types = _floating_types + _integral_types
def all_types():
    """Return the shared dispatch tuple of the floating dtypes followed by the integer dtypes."""
    return _all_types

def all_types_and(*dtypes):
    """Return the floating and integer dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _all_types + extras

# (cfloat, cdouble), built once at import time.
_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
    """Return the shared dispatch tuple ``(torch.cfloat, torch.cdouble)``."""
    return _complex_types

def complex_types_and(*dtypes):
    """Return the complex dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _complex_types + extras

# The floating and integral types followed by the complex types, built once at import time.
_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    """Return the shared dispatch tuple of the floating, integer and complex dtypes, in that order."""
    return _all_types_and_complex

def all_types_and_complex_and(*dtypes):
    """Return the floating, integer and complex dtypes followed by the given extra ``dtypes``.

    The extras are passed through ``_validate_dtypes`` before being appended.
    """
    extras = _validate_dtypes(*dtypes)
    return _all_types_and_complex + extras

# The floating and integral types followed by torch.half, built once at import time.
_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    """Return the shared dispatch tuple of the floating and integer dtypes followed by ``torch.half``."""
    return _all_types_and_half

# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro

# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False,
                   include_qint=False,
                   ) -> List[torch.dtype]:
    """Build a new list of dtypes: integers, floats, then the optional groups.

    Args:
        include_half: Forwarded to ``get_all_fp_dtypes``.
        include_bfloat16: Forwarded to ``get_all_fp_dtypes``.
        include_bool: Append ``torch.bool`` after the floating dtypes.
        include_complex: Append the result of ``get_all_complex_dtypes``.
        include_complex32: Forwarded to ``get_all_complex_dtypes``; only
            consulted when ``include_complex`` is true.
        include_qint: Append the result of ``get_all_qint_dtypes`` last.

    Returns:
        A list in the order: integer, floating, bool, complex, quantized.
    """
    integers = get_all_int_dtypes()
    floating = get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
    collected = [*integers, *floating]
    if include_bool:
        collected.append(torch.bool)
    if include_complex:
        collected.extend(get_all_complex_dtypes(include_complex32))
    if include_qint:
        collected.extend(get_all_qint_dtypes())
    return collected

def get_all_math_dtypes(device) -> List[torch.dtype]:
    """Return the integer, floating and complex dtypes for ``device``.

    ``torch.float16`` is included only for CUDA devices and ``torch.bfloat16``
    is never included.

    Args:
        device: Either a device string such as ``'cpu'`` or ``'cuda:0'``, or a
            ``torch.device`` instance.

    Returns:
        A new list: integer dtypes, then floating dtypes, then complex dtypes.
    """
    # A torch.device has no ``startswith``; use its ``type`` string ('cpu', 'cuda', ...)
    # so both spellings of a device are accepted. Plain strings pass through unchanged.
    device_type = device.type if isinstance(device, torch.device) else device
    on_cuda = device_type.startswith('cuda')
    return (get_all_int_dtypes()
            + get_all_fp_dtypes(include_half=on_cuda, include_bfloat16=False)
            + get_all_complex_dtypes())

def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    """Return a new list of complex dtypes.

    Args:
        include_complex32: When true, ``torch.complex32`` is placed first.

    Returns:
        ``[complex64, complex128]``, optionally preceded by ``complex32``.
    """
    dtypes = [torch.complex64, torch.complex128]
    if include_complex32:
        dtypes.insert(0, torch.complex32)
    return dtypes


def get_all_int_dtypes() -> List[torch.dtype]:
    """Return a new list of the integer dtypes: ``uint8`` first, then ``int8`` through ``int64``."""
    signed = [torch.int8, torch.int16, torch.int32, torch.int64]
    return [torch.uint8, *signed]


def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    """Return a new list of floating-point dtypes.

    Args:
        include_half: Append ``torch.float16`` after the base dtypes.
        include_bfloat16: Append ``torch.bfloat16`` last.

    Returns:
        ``[float32, float64]`` followed by the requested optional dtypes.
    """
    half_part = [torch.float16] if include_half else []
    bfloat16_part = [torch.bfloat16] if include_bfloat16 else []
    return [torch.float32, torch.float64, *half_part, *bfloat16_part]


def get_all_qint_dtypes() -> List[torch.dtype]:
    """Return a new list of the quantized integer dtypes."""
    quantized = (
        torch.qint8,
        torch.quint8,
        torch.qint32,
        torch.quint4x2,
        torch.quint2x4,
    )
    return list(quantized)


def get_all_device_types() -> List[str]:
    """Return ``['cpu']``, plus ``'cuda'`` when ``torch.cuda.is_available()`` reports true."""
    if torch.cuda.is_available():
        return ['cpu', 'cuda']
    return ['cpu']