|
import sys |
|
import os |
|
import torch |
|
import warnings |
|
from contextlib import contextmanager |
|
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation |
|
|
|
try: |
|
from torch._C import _cudnn |
|
except ImportError: |
|
_cudnn = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cached cuDNN version (the integer returned by ``_cudnn.getVersionInt()``). It stays
# ``None`` until ``_init()`` has completed its compatibility check successfully.
__cudnn_version = None


if _cudnn is not None:
    def _init():
        """Populate the cuDNN version cache after checking runtime/compile compatibility.

        The check runs until it first succeeds. The version is cached only after
        the check passes, so an incompatible setup is reported (re-raised) on
        every call instead of only the first one.

        Returns:
            True -- this definition only exists when ``torch._C._cudnn`` imported.

        Raises:
            RuntimeError: if the cuDNN runtime version is incompatible with the
                version PyTorch was compiled against.
        """
        global __cudnn_version
        if __cudnn_version is None:
            version_int = _cudnn.getVersionInt()
            runtime_version = _cudnn.getRuntimeVersion()
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version
            # Major versions must always match. Below major version 7, or when
            # ``_cudnn.is_cuda`` is false (presumably a ROCm/MIOpen build), the minor
            # version must match exactly; otherwise a newer runtime minor is accepted.
            if runtime_major != compile_major:
                cudnn_compatible = False
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor
            if not cudnn_compatible:
                base_error_msg = (f'cuDNN version incompatibility: '
                                  f'PyTorch was compiled against {compile_version} '
                                  f'but found runtime version {runtime_version}. '
                                  'PyTorch already comes bundled with cuDNN. '
                                  'One option to resolving this error is to ensure PyTorch '
                                  'can find the bundled cuDNN.')
                ld_library_path = os.environ.get('LD_LIBRARY_PATH')
                if ld_library_path is None:
                    raise RuntimeError(base_error_msg)
                # A CUDA/cuDNN directory on LD_LIBRARY_PATH is the likely source of a
                # mismatched library being picked up instead of the bundled one.
                if any(substring in ld_library_path for substring in ('cuda', 'cudnn')):
                    raise RuntimeError(f'{base_error_msg} '
                                       'Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. '
                                       f'Please either remove it from the path or install cudnn {compile_version}')
                raise RuntimeError(f'{base_error_msg} '
                                   'One possibility is that there is a '
                                   'conflicting cuDNN in LD_LIBRARY_PATH.')
            # Cache only after the check succeeded, so failures are not silently skipped later.
            __cudnn_version = version_int
        return True
else:
    def _init():
        """Fallback used when ``torch._C._cudnn`` could not be imported; always returns False."""
        return False
|
|
|
|
|
def version():
    """Return the cached cuDNN version number, or ``None`` if cuDNN bindings are absent.

    The value is the integer cached by ``_init()``. Any ``RuntimeError`` raised by
    ``_init()``'s compatibility check propagates to the caller.
    """
    return __cudnn_version if _init() else None
|
|
|
|
|
# Floating-point dtypes (fp16, fp32, fp64) that ``is_acceptable`` admits for cuDNN use.
CUDNN_TENSOR_DTYPES = {torch.float16, torch.float32, torch.float64}
|
|
|
|
|
def is_available():
    r"""Return whether this PyTorch build has cuDNN support, per ``torch._C.has_cudnn``."""
    built_with_cudnn = torch._C.has_cudnn
    return built_with_cudnn
|
|
|
|
|
def is_acceptable(tensor):
    r"""Return whether cuDNN may be used for ``tensor``.

    Returns False when cuDNN is disabled, when ``tensor`` is not on a ``'cuda'``
    device, or when its dtype is not in ``CUDNN_TENSOR_DTYPES``. Also returns
    False, after emitting a warning, when PyTorch was built without cuDNN/MIOpen
    or when ``_init()`` reports the library is unavailable.
    """
    if not torch._C._get_cudnn_enabled():
        return False
    if tensor.device.type != 'cuda':
        return False
    if tensor.dtype not in CUDNN_TENSOR_DTYPES:
        return False
    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system.")
        return False
    if _init():
        return True
    # Name the platform-specific environment variable the dynamic loader searches.
    search_path_vars = {'darwin': 'DYLD_LIBRARY_PATH', 'win32': 'PATH'}
    env_var = search_path_vars.get(sys.platform, 'LD_LIBRARY_PATH')
    warnings.warn(f'cuDNN/MIOpen library not found. Check your {env_var}')
    return False
|
|
|
|
|
def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None):
    r"""Set global cuDNN flags and return the values they had before the call.

    Any argument left as ``None`` leaves its flag unchanged. The benchmark limit
    is only read and written when ``is_available()`` is true.

    Returns:
        Tuple ``(enabled, benchmark, benchmark_limit, deterministic, allow_tf32)``
        holding the previous values; ``benchmark_limit`` is ``None`` when cuDNN
        is unavailable. Passing it back as ``set_flags(*previous)`` restores them.
    """
    c_api = torch._C
    previous = (
        c_api._get_cudnn_enabled(),
        c_api._get_cudnn_benchmark(),
        c_api._cuda_get_cudnn_benchmark_limit() if is_available() else None,
        c_api._get_cudnn_deterministic(),
        c_api._get_cudnn_allow_tf32(),
    )
    if _enabled is not None:
        c_api._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        c_api._set_cudnn_benchmark(_benchmark)
    # The benchmark-limit bindings are only touched when cuDNN is available.
    if _benchmark_limit is not None and is_available():
        c_api._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
    if _deterministic is not None:
        c_api._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        c_api._set_cudnn_allow_tf32(_allow_tf32)
    return previous
|
|
|
|
|
@contextmanager
def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True):
    r"""Temporarily apply cuDNN flags for the body of a ``with`` statement.

    Note the defaults: ``with flags():`` runs its body with cuDNN disabled
    (``enabled=False``). The previous flag values are restored on exit, including
    when the body raises.
    """
    # NOTE(review): __allow_nonbracketed_mutation presumably permits flag mutation
    # even when global flags are frozen — confirm in torch.backends.
    with __allow_nonbracketed_mutation():
        saved = set_flags(
            _enabled=enabled,
            _benchmark=benchmark,
            _benchmark_limit=benchmark_limit,
            _deterministic=deterministic,
            _allow_tf32=allow_tf32,
        )
    try:
        yield
    finally:
        with __allow_nonbracketed_mutation():
            set_flags(*saved)
|
|
|
|
|
|
|
|
|
|
|
|
|
class CudnnModule(PropModule):
    """``PropModule`` exposing the global cuDNN flags as class-level ``ContextProp`` attributes.

    Each attribute pairs a ``torch._C`` getter with its setter. ``benchmark_limit``
    is ``None`` when ``is_available()`` was false at class-definition time.
    """

    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
    benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
    # The benchmark-limit bindings are only referenced when cuDNN is available.
    benchmark_limit = (
        ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit)
        if is_available()
        else None
    )
    allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)
|
|
|
|
|
|
|
# Replace this module's entry in ``sys.modules`` with a ``CudnnModule`` wrapping it, so
# imports of this module resolve to the wrapper, whose class-level ``ContextProp``
# attributes (``enabled``, ``benchmark``, ...) then serve as module attributes.
# NOTE(review): presumably PropModule forwards other lookups to the wrapped module — confirm in torch.backends.
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)


# Annotations only (no values assigned here); presumably they let static type checkers
# know the types of the flag attributes provided through CudnnModule.
enabled: bool

deterministic: bool

benchmark: bool

allow_tf32: bool

# NOTE(review): CudnnModule sets benchmark_limit to None when cuDNN is unavailable,
# so ``int`` does not cover every case.
benchmark_limit: int
|
|