import torch
from contextlib import suppress


def get_autocast(precision):
    """Return a context-manager factory for the requested AMP precision.

    Args:
        precision: one of ``"amp"`` (device-default autocast dtype),
            ``"amp_bfloat16"`` / ``"amp_bf16"`` (autocast pinned to bfloat16),
            or anything else (no mixed precision).

    Returns:
        A zero-argument callable producing a context manager: the
        device-appropriate ``autocast`` for AMP modes, or
        ``contextlib.suppress`` (a no-op context manager when called with
        no arguments) for every other precision string.
    """
    amp_modes = ("amp", "amp_bfloat16", "amp_bf16")
    if precision not in amp_modes:
        # No mixed precision requested: suppress() with no args is a no-op
        # context manager, so callers can always write `with autocast():`.
        return suppress

    # Pick the autocast implementation for the available device.
    device_autocast = (
        torch.cuda.amp.autocast
        if torch.cuda.is_available()
        else torch.cpu.amp.autocast
    )

    if precision == "amp":
        return device_autocast

    # amp_bfloat16 is more stable than amp float16 for clip training
    return lambda: device_autocast(dtype=torch.bfloat16)