olm-chat-7b / open_lm /precision.py
henhenhahi111112's picture
Upload folder using huggingface_hub
af6e330 verified
raw
history blame
531 Bytes
import torch
from contextlib import suppress
def get_autocast(precision):
    """Map a precision string to a context-manager factory for autocast.

    Args:
        precision: "amp" for the device's default autocast dtype
            (float16 on CUDA, bfloat16 on CPU), "amp_bfloat16" or
            "amp_bf16" for bfloat16 autocast, anything else for no
            mixed precision.

    Returns:
        A zero-argument callable that produces a context manager:
        a ``torch.amp.autocast`` instance for the amp modes, or
        ``contextlib.suppress`` (a no-op context manager when called
        with no arguments) otherwise.
    """
    # Resolve the device backend once; torch.amp.autocast(device_type, ...)
    # supersedes the deprecated torch.cuda.amp.autocast and
    # torch.cpu.amp.autocast aliases.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    if precision == "amp":
        # Forward kwargs so callers that passed options to the old
        # autocast class (e.g. enabled=...) keep working.
        return lambda **kwargs: torch.amp.autocast(device_type, **kwargs)
    elif precision in ("amp_bfloat16", "amp_bf16"):
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.amp.autocast(device_type, dtype=torch.bfloat16)
    else:
        # No autocast requested: suppress() with no args is a no-op
        # context manager.
        return suppress