Load model in bfloat16

#2

The model is currently loaded without specifying a torch_dtype, which in my testing means it is loaded in torch.float32 by default.

This PR loads the model in torch.bfloat16, the same dtype used during training. That should roughly halve memory requirements and, more importantly, speed up generation by about the same factor of 2 without loss of quality.
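For reference, a minimal sketch of the change (the repo id below is a placeholder, not the actual model):

```python
import torch
from transformers import AutoModelForCausalLM

# Passing torch_dtype explicitly keeps the weights in bfloat16,
# the dtype the model was trained in, instead of the float32 default.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",  # placeholder repo id
    torch_dtype=torch.bfloat16,
)
```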

Rijgersberg changed pull request title from Load model in bloat16 to Load model in bfloat16
BramVanroy changed pull request status to merged

Thanks! I ignorantly assumed that dtype=auto would take care of this when the safetensors metadata is all BF16.
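For completeness, a sketch of the `torch_dtype="auto"` variant mentioned above: as I understand it, when it is passed explicitly, transformers derives the dtype from the checkpoint (typically the torch_dtype recorded in config.json) rather than falling back to float32; if nothing is passed, float32 remains the default.

```python
from transformers import AutoModelForCausalLM

# With torch_dtype="auto", transformers picks up the dtype stored with the
# checkpoint instead of using the float32 default.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",  # placeholder repo id
    torch_dtype="auto",
)
```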
