Fix auto-casting with SFT sometimes only up-casting keys and not queries

#56
Files changed (1) hide show
  1. modeling_phi3.py +1 -1
modeling_phi3.py CHANGED
@@ -546,7 +546,7 @@ class Phi3FlashAttention2(Phi3Attention):
546
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
547
  # in fp32.
548
 
549
- if query_states.dtype == torch.float32:
550
  if torch.is_autocast_enabled():
551
  target_dtype = torch.get_autocast_gpu_dtype()
552
  # Handle the case where the model is quantized
 
546
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
547
  # in fp32.
548
 
549
+ if query_states.dtype == torch.float32 or key_states.dtype == torch.float32:
550
  if torch.is_autocast_enabled():
551
  target_dtype = torch.get_autocast_gpu_dtype()
552
  # Handle the case where the model is quantized