r"""This file is allowed to initialize CUDA context when imported.""" |
|
|
|
import functools |
|
import torch |
|
import torch.cuda |
|
from torch.testing._internal.common_utils import TEST_NUMBA, IS_WINDOWS, TEST_WITH_ROCM |
|
import inspect |
|
import contextlib |
|
from distutils.version import LooseVersion |
|
|
|
|
|


TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# cuDNN is only considered available if the backend accepts a CUDA tensor.
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0

CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0"
CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
# SMxy flags are true when the current device has compute capability x.y or newer.
SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)
SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
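
# Illustrative only (not used elsewhere in this file): these flags are normally
# consumed by test modules through `unittest.skipIf`. The test below is a
# hypothetical sketch.
#
#     import unittest
#
#     @unittest.skipIf(not TEST_MULTIGPU, "requires at least two CUDA devices")
#     def test_cross_device_copy(self):
#         x = torch.randn(4, device="cuda:0")
#         self.assertEqual(x.to("cuda:1").cpu(), x.cpu())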


TEST_MAGMA = TEST_CUDA
if TEST_CUDA:
    # A CUDA context must be initialized before torch.cuda.has_magma is meaningful.
    torch.ones(1).cuda()
    TEST_MAGMA = torch.cuda.has_magma


if TEST_NUMBA:
    import numba.cuda
    TEST_NUMBA_CUDA = numba.cuda.is_available()
else:
    TEST_NUMBA_CUDA = False


# Flipped to True by initialize_cuda_context_rng() once the CUDA context and RNG
# have been initialized on every visible device.
__cuda_ctx_rng_initialized = False


# After this call, the CUDA context and RNG state have been initialized on each
# available CUDA device.
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # Initialize the CUDA context (and RNG) by drawing a random number on each device.
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device="cuda:{}".format(i))
        __cuda_ctx_rng_initialized = True


# True when TF32 math is actually distinct from FP32, i.e. we are running on an
# Ampere-or-newer GPU (compute capability >= 8.0) with a CUDA 11+ build.
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    # Disable TF32 for matmuls and cuDNN convolutions for the duration of the
    # `with` block; the previous matmul setting is restored on exit.
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
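
# Illustrative only: a minimal sketch (with hypothetical tensors `a` and `b`) of
# forcing full FP32 precision for a single operation inside a test.
#
#     with tf32_off():
#         c = torch.matmul(a, b)  # computed without TF32, regardless of global settings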


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    # `self` is the running test case; its comparison tolerance `self.precision`
    # is relaxed to `tf32_precision` while TF32 is enabled, then restored.
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision
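
# Illustrative only: a hypothetical test method using `tf32_on`. The reference
# result is computed on CPU in double precision before TF32 is enabled.
#
#     def test_matmul_tf32(self):
#         a = torch.randn(64, 64, device="cuda")
#         b = torch.randn(64, 64, device="cuda")
#         expected = (a.double().cpu() @ b.double().cpu()).float()
#         with tf32_on(self, tf32_precision=0.005):
#             self.assertEqual(torch.matmul(a, b).cpu(), expected)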


# Wraps a test so that it runs twice whenever TF32 could change the result: once
# with TF32 disabled and once with TF32 enabled, using `tf32_precision` as the
# relaxed comparison tolerance for the TF32 run. The double run only happens when
# tf32_is_not_fp32() is true and, if the test takes `device`/`dtype` arguments,
# when the device is CUDA and the dtype is float32 or complex64; otherwise the
# test runs once, unmodified.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Bind positional arguments to their names so `self`, `device` and
            # `dtype` can be looked up uniformly below.
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped
    return wrapper
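
# Illustrative only: a hypothetical device/dtype-parametrized test using the
# decorator. On an Ampere-or-newer GPU with a CUDA 11+ build, it runs twice for
# float32/complex64 on CUDA (TF32 off, then TF32 on with tolerance 0.005); in
# every other configuration it runs once, unchanged.
#
#     @tf32_on_and_off(0.005)
#     def test_matmul(self, device, dtype):
#         a = torch.randn(64, 64, device=device, dtype=dtype)
#         b = torch.randn(64, 64, device=device, dtype=dtype)
#         self.assertEqual(torch.matmul(a, b), torch.mm(a, b))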


# Decorator that keeps TF32 disabled for the entire test, on every device,
# dtype, and hardware generation.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
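
# Illustrative only: a hypothetical test whose whole body must run in strict FP32.
#
#     @with_tf32_off
#     def test_exact_fp32_matmul(self, device):
#         ...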


def _get_magma_version():
    # Parse the MAGMA version out of torch.__config__.show(); (0, 0) means this
    # build does not ship MAGMA.
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))


def _get_torch_cuda_version():
    # (0, 0) means this is not a CUDA build.
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))
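
# Illustrative only: a hypothetical skip that gates a test on the CUDA version
# the current torch build was compiled against.
#
#     import unittest
#
#     @unittest.skipIf(_get_torch_cuda_version() < (11, 0), "requires a CUDA 11+ build")
#     def test_needs_cuda11(self):
#         ...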


def _check_cusparse_generic_available():
    # The generic cuSPARSE API is considered available on CUDA 10.1+ builds
    # (11.0+ on Windows).
    version = _get_torch_cuda_version()
    min_supported_version = (10, 1)
    if IS_WINDOWS:
        min_supported_version = (11, 0)
    return version >= min_supported_version


def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    # torch.version.hip may carry a build suffix after '-'; drop it and require
    # ROCm 5.1 or newer for the generic hipSPARSE API.
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return rocm_version_tuple >= (5, 1)


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()