# zdou0830's picture
# desco
# 749745d
from contextlib import contextmanager
@contextmanager
def nullcontext(enter_result=None, **kwargs):
    """Do-nothing context manager that yields ``enter_result``.

    Args:
        enter_result: Value bound by ``with nullcontext(...) as value``.
            Defaults to ``None``.
        **kwargs: Accepted and ignored, so call sites may pass arbitrary
            keyword arguments without raising ``TypeError``.

    Yields:
        ``enter_result``, unchanged.

    Because it is built with ``contextlib.contextmanager``, the returned
    object can be used in a ``with`` statement or applied as a decorator.
    """
    # Nothing to set up or tear down; exceptions raised inside the ``with``
    # body propagate to the caller untouched.
    yield enter_result
def _passthrough_decorator(fn=None, **kwargs):
    """Fallback for ``custom_fwd`` / ``custom_bwd`` when AMP is unavailable.

    Used bare (``@custom_fwd``), it returns the decorated function unchanged.
    Any other call shape (for example ``custom_fwd(cast_inputs=...)``) is
    forwarded to ``nullcontext``, exactly as the previous fallback did, so the
    result can still be used in a ``with`` statement or as a decorator.
    """
    if callable(fn) and not kwargs:
        return fn
    return nullcontext(fn, **kwargs)


try:
    from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
except ImportError:
    # Only a failed import means "AMP not available"; anything else
    # (KeyboardInterrupt, SystemExit, unrelated bugs) must not be swallowed.
    print("[Warning] Library for automatic mixed precision is not found, AMP is disabled!!")
    # NOTE(review): these stand-ins only make ``GradScaler(...)`` and
    # ``autocast(...)`` callable; the object returned for GradScaler has no
    # ``scale``/``step``/``update`` methods -- confirm callers guard on AMP.
    GradScaler = nullcontext
    autocast = nullcontext
    custom_fwd = _passthrough_decorator
    custom_bwd = _passthrough_decorator