Loading example fails on GPUs without bfloat16 support.

#1
by RASMUS - opened

Please change the loading example so it also works on GPUs without bfloat16 support, to something like:

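import torch
import transformers

branch = "200B"
# Fall back to float16 when there is no CUDA device or the GPU lacks bfloat16 support
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
model = transformers.AutoModelForCausalLM.from_pretrained(
    "LumiOpen/Viking-7B",
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    revision=branch,
)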
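The dtype decision behind this suggestion can be sketched without PyTorch as a small function. It assumes that native bfloat16 arithmetic needs CUDA compute capability 8.0 (Ampere) or newer; `choose_dtype_name` is a hypothetical helper for illustration, not part of transformers or torch.

```python
def choose_dtype_name(capability):
    """Return the dtype name to load the model with.

    Hypothetical helper. `capability` is a (major, minor) CUDA compute
    capability tuple, or None when no CUDA device is present. Assumes
    native bfloat16 requires compute capability 8.0 or newer.
    """
    if capability is None:
        # No GPU: fall back to float16, as in the suggested example.
        return "float16"
    major, _minor = capability
    return "bfloat16" if major >= 8 else "float16"


# Example capabilities: V100 is (7, 0), T4 is (7, 5), A100 is (8, 0).
for cap in [(7, 0), (7, 5), (8, 0), None]:
    print(cap, choose_dtype_name(cap))
```

With PyTorch installed, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple to pass in, and `getattr(torch, name)` turns the returned name into a dtype for `torch_dtype`.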

LumiOpen org

ah, thanks!
