cArlIcon commited on
Commit
ee9da3c
1 Parent(s): b5e0701

make flash-attn optional

Browse files
Files changed (1) hide show
  1. modeling_yi.py +10 -12
modeling_yi.py CHANGED
@@ -6,7 +6,6 @@ import torch.utils.checkpoint
6
  from einops import repeat
7
  from torch import nn
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
-
10
  from transformers.activations import ACT2FN
11
  from transformers.modeling_outputs import (
12
  BaseModelOutputWithPast,
@@ -18,17 +17,17 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
18
  from transformers.utils import (
19
  add_start_docstrings,
20
  add_start_docstrings_to_model_forward,
21
- is_flash_attn_available,
22
  logging,
23
  replace_return_docstrings,
24
  )
25
 
26
  from .configuration_yi import YiConfig
27
 
28
-
29
- if is_flash_attn_available():
30
  from flash_attn import flash_attn_func
31
-
 
32
 
33
  logger = logging.get_logger(__name__)
34
 
@@ -224,7 +223,6 @@ class YiAttention(nn.Module):
224
  use_cache: bool = False,
225
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
226
  bsz, q_len, _ = hidden_states.size()
227
- flash_attn_available = is_flash_attn_available()
228
 
229
  query_states = self.q_proj(hidden_states).view(
230
  bsz, q_len, self.num_heads, self.head_dim
@@ -237,7 +235,7 @@ class YiAttention(nn.Module):
237
  bsz, q_len, self.num_key_value_heads, self.head_dim
238
  )
239
 
240
- if not flash_attn_available:
241
  if self.num_key_value_groups > 1:
242
  key_states = repeat(
243
  key_states, f"b n h d -> b n (h {self.num_key_value_groups}) d"
@@ -251,13 +249,13 @@ class YiAttention(nn.Module):
251
  key_states = key_states.transpose(1, 2)
252
  value_states = value_states.transpose(1, 2)
253
 
254
- seq_dim = 1 if flash_attn_available else 2
255
  kv_seq_len = key_states.shape[seq_dim]
256
  if past_key_value is not None:
257
  kv_seq_len += past_key_value[0].shape[seq_dim]
258
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
259
  query_states, key_states = apply_rotary_pos_emb(
260
- query_states, key_states, cos, sin, position_ids, flash_attn_available
261
  )
262
 
263
  if past_key_value is not None:
@@ -267,7 +265,7 @@ class YiAttention(nn.Module):
267
 
268
  past_key_value = (key_states, value_states) if use_cache else None
269
 
270
- if flash_attn_available:
271
  attn_output = flash_attn_func(
272
  query_states, key_states, value_states, dropout_p=0.0, causal=True
273
  )
@@ -308,7 +306,7 @@ class YiAttention(nn.Module):
308
  f" {attn_output.size()}"
309
  )
310
 
311
- if not flash_attn_available:
312
  attn_output = attn_output.transpose(1, 2)
313
 
314
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -631,7 +629,7 @@ class YiModel(YiPreTrainedModel):
631
  if inputs_embeds is None:
632
  inputs_embeds = self.embed_tokens(input_ids)
633
 
634
- if not is_flash_attn_available():
635
  # embed positions
636
  if attention_mask is None:
637
  attention_mask = torch.ones(
 
6
  from einops import repeat
7
  from torch import nn
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
9
  from transformers.activations import ACT2FN
10
  from transformers.modeling_outputs import (
11
  BaseModelOutputWithPast,
 
17
  from transformers.utils import (
18
  add_start_docstrings,
19
  add_start_docstrings_to_model_forward,
 
20
  logging,
21
  replace_return_docstrings,
22
  )
23
 
24
  from .configuration_yi import YiConfig
25
 
26
+ is_flash_attn_available = True
27
+ try:
28
  from flash_attn import flash_attn_func
29
+ except Exception:
30
+ is_flash_attn_available = False
31
 
32
  logger = logging.get_logger(__name__)
33
 
 
223
  use_cache: bool = False,
224
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
225
  bsz, q_len, _ = hidden_states.size()
 
226
 
227
  query_states = self.q_proj(hidden_states).view(
228
  bsz, q_len, self.num_heads, self.head_dim
 
235
  bsz, q_len, self.num_key_value_heads, self.head_dim
236
  )
237
 
238
+ if not is_flash_attn_available:
239
  if self.num_key_value_groups > 1:
240
  key_states = repeat(
241
  key_states, f"b n h d -> b n (h {self.num_key_value_groups}) d"
 
249
  key_states = key_states.transpose(1, 2)
250
  value_states = value_states.transpose(1, 2)
251
 
252
+ seq_dim = 1 if is_flash_attn_available else 2
253
  kv_seq_len = key_states.shape[seq_dim]
254
  if past_key_value is not None:
255
  kv_seq_len += past_key_value[0].shape[seq_dim]
256
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
257
  query_states, key_states = apply_rotary_pos_emb(
258
+ query_states, key_states, cos, sin, position_ids, is_flash_attn_available
259
  )
260
 
261
  if past_key_value is not None:
 
265
 
266
  past_key_value = (key_states, value_states) if use_cache else None
267
 
268
+ if is_flash_attn_available:
269
  attn_output = flash_attn_func(
270
  query_states, key_states, value_states, dropout_p=0.0, causal=True
271
  )
 
306
  f" {attn_output.size()}"
307
  )
308
 
309
+ if not is_flash_attn_available:
310
  attn_output = attn_output.transpose(1, 2)
311
 
312
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
629
  if inputs_embeds is None:
630
  inputs_embeds = self.embed_tokens(input_ids)
631
 
632
+ if not is_flash_attn_available:
633
  # embed positions
634
  if attention_mask is None:
635
  attention_mask = torch.ones(