Guess-What-Moves / determinism.py
subhc's picture
Code Commit
5e88f62
import os
def _version_at_least(version, minimum):
    """Return True if the dotted ``version`` string is >= the ``minimum`` tuple.

    Only the leading numeric release segment is compared, so nightly / local
    suffixes such as ``1.12.0.dev20220301`` or ``1.13.0a0+git1234`` are
    ignored.  Components are compared as integers; a plain string comparison
    would wrongly rank ``'1.9.0'`` above ``'1.12.0'``.

    Args:
        version: Version string, e.g. ``'1.12.0.dev20220301+cu113'``.
        minimum: Tuple of ints, e.g. ``(1, 12)``.

    Returns:
        bool: False when ``version`` has no leading numeric release segment.
    """
    import re  # local import keeps this block self-contained

    release = re.match(r'\d+(?:\.\d+)*', str(version))
    if release is None:
        return False
    found = tuple(int(part) for part in release.group().split('.'))
    return found >= tuple(minimum)


# Requested determinism level, read once at import time:
#   0  -> leave torch untouched (default)
#   1  -> disable cuDNN autotuning (benchmark mode)
#   2  -> additionally request deterministic algorithms, warning when missing
#   3+ -> additionally raise when a deterministic implementation is missing
lvl = int(os.environ.get('TRY_DETERMISM_LVL', '0'))
if lvl > 0:
    print(f'Attempting to enable deterministic cuDNN and cuBLAS operations to lvl {lvl}')

if lvl >= 2:
    # turn on deterministic operations
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"  # Need to set before torch gets loaded
    import torch

    # Nightly builds report versions like 1.12.0.devXXXXXXX, so compare the
    # numeric release components instead of the raw strings.
    if _version_at_least(torch.version.__version__, (1, 12)):
        torch.use_deterministic_algorithms(True, warn_only=(lvl < 3))
    elif lvl >= 3:
        torch.use_deterministic_algorithms(True)  # This will throw errors if implementations are missing
    else:
        print(f"Torch verions is only {torch.version.__version__}, which will cause errors on lvl {lvl}")

if lvl >= 1:
    import torch

    if torch.cuda.is_available():
        # cuDNN benchmark mode may select different kernels from run to run.
        torch.backends.cudnn.benchmark = False
def i_do_nothing_but_dont_remove_me_otherwise_things_break():
    """Intentional no-op.

    Gives importers a name to reference so that formatters do not treat
    this module as dead code and strip it.
    """
    return None