gugarosa committed on
Commit
1a4c7ae
1 Parent(s): accfee5

Update modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +4 -4
modeling_phi.py CHANGED
@@ -302,6 +302,9 @@ class PhiAttention(nn.Module):
302
  else:
303
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
304
 
 
 
 
305
  def forward(
306
  self,
307
  hidden_states: torch.Tensor,
@@ -359,10 +362,7 @@ class PhiAttention(nn.Module):
359
  key_states = repeat_kv(key_states, self.num_key_value_groups)
360
  value_states = repeat_kv(value_states, self.num_key_value_groups)
361
 
362
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
363
- attn_weights = torch.matmul(
364
- query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
365
- ) / math.sqrt(self.head_dim)
366
 
367
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
368
  raise ValueError(
 
302
  else:
303
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
304
 
305
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
306
+ @torch.autocast("cpu", enabled=False)
307
+ @torch.autocast("cuda", enabled=False)
308
  def forward(
309
  self,
310
  hidden_states: torch.Tensor,
 
362
  key_states = repeat_kv(key_states, self.num_key_value_groups)
363
  value_states = repeat_kv(value_states, self.num_key_value_groups)
364
 
365
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
 
366
 
367
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
368
  raise ValueError(