jon-tow committed
Commit 1383d99
1 parent: 1b08367

Add `flash-attn v2` support

Files changed (1)
  1. modeling_stablelm_epoch.py +249 -19
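
For context, a minimal usage sketch of what this change enables on the `transformers` side once a checkpoint bundles this `modeling_stablelm_epoch.py`. The checkpoint id, dtype, and prompt below are illustrative assumptions, not part of this commit; it also assumes a working `flash-attn>=2.x` install and a recent `transformers` (the `attn_implementation` argument landed in v4.36):

```python
# Usage sketch (assumptions: illustrative checkpoint id, flash-attn 2.x installed,
# CUDA device available; transformers>=4.36 for `attn_implementation`).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stabilityai/stablelm-3b-4e1t"  # illustrative; any repo carrying this modeling file

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # flash-attn kernels need fp16/bf16
    attn_implementation="flash_attention_2",  # routes DecoderLayer to FlashAttention2
    trust_remote_code=True,                   # loads modeling_stablelm_epoch.py from the repo
).to("cuda")

inputs = tokenizer("The weather today is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

When `attn_implementation` is left at its default, `ATTENTION_CLASSES["eager"]` (the existing `Attention` module) is used, so the change is backward compatible.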
modeling_stablelm_epoch.py CHANGED
@@ -19,23 +19,48 @@
 """ PyTorch StableLM Epoch model. """
 from typing import Optional, Tuple, Union
 import math
+import warnings
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
+from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
+
 from .configuration_stablelm_epoch import StableLMEpochConfig
 
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except ImportError:
+    flash_attn_func, flash_attn_varlen_func = None, None
+    index_first_axis, pad_input, unpad_input = None, None, None
+
 
 logger = logging.get_logger(__name__)
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -165,12 +190,14 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.is_causal = True
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
@@ -269,10 +296,202 @@ class Attention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+class FlashAttention2(Attention):
+    """
+    Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle that difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        query_rot = query_states[..., : self.rotary_ndims]
+        query_pass = query_states[..., self.rotary_ndims :]
+        key_rot = key_states[..., : self.rotary_ndims]
+        key_pass = key_states[..., self.rotary_ndims :]
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, num_heads, seq_len, head_dim]
+        query_states = torch.cat((query_states, query_pass), dim=-1)
+        key_states = torch.cat((key_states, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            # Reuse k, v, self_attention
+            key_states = torch.cat((past_key_value[0], key_states), dim=2)
+            value_states = torch.cat((past_key_value[1], value_states), dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpads the input, then computes the attention scores and pads the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+ATTENTION_CLASSES = {
+    "eager": Attention,
+    "flash_attention_2": FlashAttention2,
+}
+
+
 class DecoderLayer(nn.Module):
     def __init__(self, config: StableLMEpochConfig):
         super().__init__()
-        self.self_attn = Attention(config)
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
         self.mlp = MLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
@@ -328,6 +547,7 @@ class StableLMEpochPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -355,6 +575,7 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -428,10 +649,6 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
@@ -447,18 +664,22 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # Embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past),
-                dtype=torch.bool,
-                device=inputs_embeds.device,
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past),
+                    dtype=torch.bool,
+                    device=inputs_embeds.device,
+                )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
 
         hidden_states = inputs_embeds
 
@@ -643,8 +864,17 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
         **kwargs,
     ):
         # Trim decoder_input_ids if past is used
-        if past_key_values and past_key_values[0] is not None:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
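
As a side note on the variable-length (padded) path added above, here is a small standalone sketch of the bookkeeping that `_get_unpad_data` produces before `flash_attn_varlen_func` is called. The toy mask and the printed values are illustrative only:

```python
# Toy illustration of the unpadding metadata consumed by _upad_input / flash_attn_varlen_func.
import torch
import torch.nn.functional as F

# 1 = real token, 0 = padding; two sequences of lengths 3 and 2.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> per-sequence boundaries in the packed batch
```

`indices` selects only the real tokens when the batch is flattened, and `cu_seqlens` marks where each sequence starts and ends in that packed layout, which is exactly what the varlen kernel expects.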