Why set `dtype=torch.float32` in https://huggingface.co/mosaicml/mpt-7b/commit/c7c00aa2f381ea13ee3feef3b29f026fc61617ad?

#57
by emilylearning - opened

This seems to conflict with the Model Card suggestion:
`torch_dtype=torch.bfloat16, # Load model weights in bfloat16`

This causes:
`RuntimeError: expected scalar type BFloat16 but found Float`

Pinning `revision="2addd09bac4237fcd63421de158186abaede0285"` fixes the issue.
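For reference, pinning the revision with `transformers` might look like the sketch below. It assumes the standard `AutoModelForCausalLM.from_pretrained` arguments (`revision`, `torch_dtype`, `trust_remote_code`); the helper name `load_pinned_mpt` is hypothetical, and it is only defined here, not called, so nothing is downloaded.

```python
# Commit hash quoted above as a known-good revision (before the fp32 change).
MPT_REVISION = "2addd09bac4237fcd63421de158186abaede0285"


def load_pinned_mpt(model_id: str = "mosaicml/mpt-7b"):
    """Hypothetical helper: load MPT at a pinned commit of the Hub repo."""
    # Imported lazily so merely defining this helper needs neither library.
    import torch
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id,
        revision=MPT_REVISION,       # pin the repo snapshot (should also pin the remote modeling code)
        torch_dtype=torch.bfloat16,  # as suggested by the Model Card
        trust_remote_code=True,      # MPT ships its own modeling code on the Hub
    )
```

Pinning only freezes the older behavior; it does not address the fp32 ALiBi requirement discussed below.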

+1, this also caused issues on our end today.

We get `RuntimeError: expected scalar type Half but found Float`, both when loading the model out of the box and when trying to load it in 8-bit.

You're running the model in lower precision (fp16 or bf16), but the ALiBi bias needs to be in fp32, or else model performance degrades. To get those to work together correctly, you should use autocast. Here is an example of how we had to update our tests to get this right: https://github.com/mosaicml/llm-foundry/pull/329/files#diff-3b8a58a4d021803b3171b886bb9162fd659e671131f3f61036f9210cb5d0bc7cR809
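To illustrate why autocast helps, here is a toy sketch. It is not MPT's actual attention code: `toy_attention`, the shapes, and the slope-1 distance bias are hypothetical stand-ins for ALiBi. Adding an fp32 bias to bf16 attention scores promotes them to fp32, and the following matmul with bf16 values then fails on mismatched dtypes. Inside `torch.autocast`, matmul inputs are cast to the autocast dtype, so the same code runs.

```python
import torch


def toy_attention(q, k, v, bias):
    """Single-head attention with an additive positional bias (toy stand-in for ALiBi)."""
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    # An fp32 bias promotes the bf16 scores to fp32 via ordinary type promotion.
    scores = scores + bias
    weights = torch.softmax(scores, dim=-1)
    # Without autocast: fp32 weights @ bf16 values -> RuntimeError (matmul needs matching dtypes).
    # Under autocast: both matmul inputs are cast to the autocast dtype, so this succeeds.
    return weights @ v


torch.manual_seed(0)
seq_len, head_dim = 4, 8
q, k, v = (torch.randn(seq_len, head_dim, dtype=torch.bfloat16) for _ in range(3))

# Hypothetical bias: a simple distance penalty kept in fp32, as ALiBi should be.
pos = torch.arange(seq_len, dtype=torch.float32)
bias = -(pos[None, :] - pos[:, None]).abs()

try:
    toy_attention(q, k, v, bias)
    failed_without_autocast = False
except RuntimeError as err:
    failed_without_autocast = True
    print("without autocast:", err)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = toy_attention(q, k, v, bias)
print("with autocast:", out.dtype, tuple(out.shape))
```

Applied to the real model, the analogous step would be wrapping the forward or `generate` call in `torch.autocast("cuda", dtype=torch.bfloat16)` (or `torch.float16`) while the bias itself stays in fp32.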

sam-mosaic changed discussion status to closed
