Fix "InternLM2ForCausalLM does not support Flash Attention 2.0 yet"

#3
Files changed (1)
  1. modeling_internlm2.py +2 -0
modeling_internlm2.py CHANGED
@@ -709,6 +709,8 @@ class InternLM2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ['InternLM2DecoderLayer']
     _skip_keys_device_placement = 'past_key_values'
+    _supports_flash_attn_2 = True
+
 
     def _init_weights(self, module):
         std = self.config.initializer_range
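
For context, `_supports_flash_attn_2` is the class attribute `PreTrainedModel` checks when a user requests the Flash Attention 2 implementation; without it, `from_pretrained` raises the "does not support Flash Attention 2.0 yet" error in the title. A minimal sketch of loading the model after this change — assuming the `flash-attn` package is installed, a CUDA GPU is available, and using an illustrative checkpoint name:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name is illustrative; this model class lives in remote code
# on the Hub, so trust_remote_code=True is required.
model = AutoModelForCausalLM.from_pretrained(
    "internlm/internlm2-7b",
    torch_dtype=torch.bfloat16,               # FA2 requires fp16/bf16
    attn_implementation="flash_attention_2",  # enabled by _supports_flash_attn_2
    device_map="auto",                        # FA2 kernels run on CUDA
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained("internlm/internlm2-7b", trust_remote_code=True)

inputs = tok("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))
```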