File size: 1,022 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import re

# Determinism level, read once at import time from TRY_DETERMISM_LVL:
#   0  -> do nothing (default)
#   1  -> disable cuDNN benchmark mode (only when CUDA is available)
#   2  -> also set CUBLAS_WORKSPACE_CONFIG and, on torch >= 1.12, request
#         deterministic algorithms with warn_only=True (older torch: notice only)
#   3+ -> as 2, but non-deterministic ops raise instead of warn
# NOTE: the env var's spelling ("DETERMISM") is part of the external interface
# and is intentionally kept. An empty/whitespace-only value is treated as 0.
lvl = int(os.environ.get('TRY_DETERMISM_LVL', '').strip() or '0')


def _release_version(version):
    """Return the leading numeric release components of a version string.

    Pre-release, dev and local suffixes are ignored, so ``'1.12.0.dev20220301'``
    and ``'1.12.0a0+git1234'`` both yield ``(1, 12, 0)``. Comparing these tuples
    avoids the lexicographic trap of plain string comparison, where
    ``'1.9.0' >= '1.12.0'`` is ``True``.

    Returns an empty tuple if *version* does not start with a digit.
    """
    match = re.match(r'\d+(?:\.\d+)*', str(version))
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split('.'))


if lvl > 0:
    print(f'Attempting to enable deterministic cuDNN and cuBLAS operations to lvl {lvl}')
if lvl >= 2:
    # Must be in the environment before cuBLAS is initialised; setting it
    # before torch is imported guarantees that.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
    import torch
    # Compare numeric release tuples, not raw strings (see _release_version).
    # Unstable builds look like 1.12.0.devXXXXXXX and still count as 1.12.0.
    if _release_version(torch.version.__version__) >= (1, 12, 0):
        # lvl 2: non-deterministic ops only warn; lvl >= 3: they raise.
        torch.use_deterministic_algorithms(True, warn_only=(lvl < 3))
    elif lvl >= 3:
        # No warn_only on this torch; this will throw errors if deterministic
        # implementations are missing.
        torch.use_deterministic_algorithms(True)
    else:
        print(f"Torch verions is only {torch.version.__version__}, which will cause errors on lvl {lvl}")
if lvl >= 1:
    import torch
    if torch.cuda.is_available():
        # Benchmark mode picks cuDNN algorithms by timing them, which can
        # differ between runs; turn it off for reproducibility.
        torch.backends.cudnn.benchmark = False


def i_do_nothing_but_dont_remove_me_otherwise_things_break():
    """Deliberate no-op; keep it.

    This module does its real work at import time (the determinism setup
    above). This function exists only so formatters and similar tooling do
    not treat the file as dead code.
    """
    return None