Fix auto-casting with SFT sometimes only up-casting keys and not queries
#56
by roborovski - opened
- modeling_phi3.py +1 -1
modeling_phi3.py CHANGED
@@ -546,7 +546,7 @@ class Phi3FlashAttention2(Phi3Attention):
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32.
 
-        if query_states.dtype == torch.float32:
+        if query_states.dtype == torch.float32 or key_states.dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
                 # Handle the case where the model is quantized
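For reference, here is a minimal standalone sketch of the cast that the patched condition guards, assuming the rest of Phi3FlashAttention2.forward follows the usual transformers pattern of casting q/k/v back to a half-precision dtype before calling the flash-attention kernel. The helper name _downcast_qkv_if_needed and the fallback_dtype argument are hypothetical simplifications added for illustration; the real code instead falls back to the pre-quantization dtype or the projection weight dtype.

import torch


def _downcast_qkv_if_needed(query_states, key_states, value_states,
                            fallback_dtype=torch.float16):
    # Flash-attention kernels only accept fp16/bf16 inputs. Per this PR,
    # autocast during SFT can up-cast only the keys to fp32 while the
    # queries stay in half precision, so the dtype check has to look at
    # both tensors rather than the queries alone.
    if query_states.dtype == torch.float32 or key_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = fallback_dtype
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states

With the original check, the key-only fp32 case slipped through untouched and produced a dtype mismatch inside the flash-attention call; checking both queries and keys closes that gap.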