import torch | |
from contextlib import suppress | |
def get_autocast(precision):
    """Return a zero-argument-callable factory for the autocast context to use.

    Args:
        precision: Precision mode string. ``"amp"`` selects mixed precision with
            the device's legacy default dtype (float16 on CUDA, bfloat16 on CPU).
            ``"amp_bfloat16"`` / ``"amp_bf16"`` select bfloat16 mixed precision.
            Any other value disables autocast.

    Returns:
        A callable that, when called, returns a context manager. For AMP modes
        it accepts the same optional ``(enabled, dtype, cache_enabled)``
        arguments as the legacy ``torch.cuda.amp.autocast`` class, so both
        ``autocast()`` and ``autocast(enabled=False)`` work for every AMP mode.
        For other modes it is ``contextlib.suppress``; calling it with no
        arguments yields a no-op context manager.
    """
    if precision not in ("amp", "amp_bfloat16", "amp_bf16"):
        # suppress() with no exception types suppresses nothing: a no-op context.
        return suppress

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    if precision == "amp":
        # Mirror the defaults of the legacy torch.cuda.amp / torch.cpu.amp classes.
        default_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    else:
        # amp_bfloat16 is more stable than amp float16 for clip training
        default_dtype = torch.bfloat16

    def autocast(enabled=True, dtype=default_dtype, cache_enabled=None):
        # Use the device-generic torch.autocast; the per-device torch.cuda.amp /
        # torch.cpu.amp entry points are deprecated in recent PyTorch releases.
        kwargs = {"dtype": dtype, "enabled": enabled}
        if cache_enabled is not None:
            # Only forward when set, so torch.autocast applies its own default.
            kwargs["cache_enabled"] = cache_enabled
        return torch.autocast(device_type, **kwargs)

    return autocast