Shaltiel committed
Commit 4e52556
1 Parent(s): 61a61ce

Upload 2 files

configuration_megatron_gpt.py CHANGED
@@ -81,7 +81,7 @@ class MegatronGPTConfig(PretrainedConfig):
             Whether to calculate and apply the relative position bias within the attention function.
             If this is False, then model.generate will require you to calculate the triangular attention
             mask and pass it through in the attention mask.
-        skip_flash_attention (`bool`, *optional*, defaults to `False`):
+        use_flash_attention (`bool`, *optional*, defaults to `False`):
             When calculating attention, whether to attempt to use flash attention if it's installed, or to always skip and use the regular method.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
@@ -120,7 +120,7 @@ class MegatronGPTConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         rope_scaling=None,
-        skip_flash_attention=False,
+        use_flash_attention=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -144,7 +144,7 @@ class MegatronGPTConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.self_attention_relative_position_bias = self_attention_relative_position_bias
         self.tie_word_embeddings = tie_word_embeddings
-        self.skip_flash_attention = skip_flash_attention
+        self.use_flash_attention = use_flash_attention
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()

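For reference, a minimal usage sketch of the renamed flag, assuming the model is loaded with trust_remote_code=True so that this custom configuration class is picked up; the repo id below is a placeholder, not the actual repository:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-org/your-megatron-gpt-model"  # placeholder repo id (assumption)

# Load the custom config and opt in to the flash-attention path via the renamed flag.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.use_flash_attention = True  # was skip_flash_attention; defaults to False

model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)
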
modeling_megatron_gpt.py CHANGED
@@ -222,7 +222,7 @@ class MegatronGPTAttention(nn.Module):
         present = (key, value) if use_cache else None
 
         # Compute attention
-        if not HAS_FLASH or output_attentions or head_mask is not None or self.config.skip_flash_attention:
+        if not HAS_FLASH or output_attentions or head_mask is not None or not self.config.use_flash_attention:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
         else:
             attn_output = self._flash_attn(query, key, value, attention_mask)
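Read positively, the updated guard means the flash-attention branch runs only when flash attention is installed, attention weights are not requested, no head mask is supplied, and the new config flag is enabled. A small standalone restatement of that condition (derived from the diff above, not part of the commit):

def should_use_flash_attention(has_flash: bool, output_attentions: bool,
                               has_head_mask: bool, use_flash_attention: bool) -> bool:
    # De Morgan restatement of the `if` guard above: the flash path is the
    # `else` branch, so it runs only when every fallback trigger is absent.
    return has_flash and not output_attentions and not has_head_mask and use_flash_attention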