import torch
from contextlib import suppress
from functools import partial


def get_autocast(precision: str, device_type: str = 'cuda'):
    """Return a factory for the autocast context matching *precision*.

    The returned object is meant to be called with no arguments to obtain a
    context manager, e.g. ``with get_autocast(precision)(): ...``.

    Args:
        precision: Precision mode name. ``'amp'`` selects float16 autocast;
            ``'amp_bfloat16'`` and ``'amp_bf16'`` select bfloat16 autocast.
            Any other value disables autocast.
        device_type: Device type forwarded to ``torch.amp.autocast``.
            Defaults to ``'cuda'``, which matches the previously hard-coded
            behaviour.

    Returns:
        A callable producing a context manager: a ``functools.partial`` over
        ``torch.amp.autocast`` for the AMP modes, otherwise
        ``contextlib.suppress`` (which, called without arguments, suppresses
        nothing and therefore acts as a no-op context).
    """
    if precision == 'amp':
        amp_dtype = torch.float16
    elif precision in ('amp_bfloat16', 'amp_bf16'):
        # amp_bfloat16 is more stable than amp float16 for clip training
        amp_dtype = torch.bfloat16
    else:
        return suppress

    # torch.amp.autocast is the device-agnostic entry point; binding the
    # arguments with partial keeps both AMP branches returning the same kind
    # of callable (and still accepting extra keyword arguments at call time).
    return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype)