Load model in bfloat16
#2 by Rijgersberg - opened
The model is currently loaded without specifying a `torch_dtype`, which in my testing defaults to loading the model in `torch.float32`.
This PR loads the model in `torch.bfloat16`, which is the same dtype as used during training. It should lower memory requirements by about a factor of 2, but more importantly: generation should also speed up by roughly the same factor, with no loss of quality.
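For context, a minimal sketch of the change being proposed (the model id below is a placeholder, since the actual repository isn't named in this excerpt):

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "org/model-name"  # placeholder: the actual repo id is not shown here

# Before this PR: no torch_dtype given, so the weights load as torch.float32
# model = AutoModelForCausalLM.from_pretrained(model_id)

# After this PR: load in bfloat16, the same dtype used during training,
# roughly halving memory use and speeding up generation
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
```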
Rijgersberg changed pull request title from "Load model in bloat16" to "Load model in bfloat16"
BramVanroy changed pull request status to merged
Thanks! I ignorantly assumed that `dtype=auto` would take care of this when the safetensors metadata is all BF16.
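For reference, a sketch of that alternative: passing `torch_dtype="auto"` asks transformers to infer the dtype from the checkpoint itself, which should also resolve to bfloat16 when the safetensors metadata is all BF16 (model id is again a placeholder):

```python
from transformers import AutoModelForCausalLM

model_id = "org/model-name"  # placeholder

# "auto" lets transformers pick the dtype stored in the checkpoint;
# with all-BF16 safetensors metadata this should also load in bfloat16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
```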