x54-729 commited on
Commit
281d770
1 Parent(s): 6e1fdc1

remove unnecessary attention_drop

Browse files
Files changed (1) hide show
  1. modeling_internlm.py +1 -3
modeling_internlm.py CHANGED
@@ -417,10 +417,8 @@ class InternLMFlashAttention2(InternLMAttention):
417
  key_states = key_states.transpose(1, 2)
418
  value_states = value_states.transpose(1, 2)
419
 
420
- dropout_rate = 0.0 if not self.training else self.attention_dropout
421
-
422
  attn_output = self._flash_attention_forward(
423
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
424
  )
425
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
426
  attn_output = self.o_proj(attn_output)
 
417
  key_states = key_states.transpose(1, 2)
418
  value_states = value_states.transpose(1, 2)
419
 
 
 
420
  attn_output = self._flash_attention_forward(
421
+ query_states, key_states, value_states, attention_mask, q_len
422
  )
423
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
424
  attn_output = self.o_proj(attn_output)