gugarosa commited on
Commit
accfee5
1 Parent(s): 39afec1

Update modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +1 -1
modeling_phi.py CHANGED
@@ -506,7 +506,7 @@ class PhiFlashAttention2(PhiAttention):
506
  value_states = value_states.to(target_dtype)
507
 
508
  attn_output = self._flash_attention_forward(
509
- query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
510
  )
511
 
512
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
 
506
  value_states = value_states.to(target_dtype)
507
 
508
  attn_output = self._flash_attention_forward(
509
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
510
  )
511
 
512
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()