Commit 60c7c48 by winglian
Parent: e8cbf50

update for recent transformers updates (#636)

* update for recent transformers updates

* fix checkpoint forward kwargs

* just pass args into torch checkpoint
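
The two checkpoint bullets come down to how torch.utils.checkpoint.checkpoint replays the wrapped callable: the reentrant implementation that was the default at the time only passes positional arguments through to the wrapped function, so the create_custom_forward wrapper has to receive every decoder-layer input positionally and in signature order. A minimal sketch of that pattern, using a toy stand-in layer rather than the real LlamaDecoderLayer:

```python
import torch
from torch.utils.checkpoint import checkpoint


def toy_layer(hidden_states, attention_mask, position_ids, padding_mask):
    # Stand-in for a decoder layer: every input arrives positionally.
    return hidden_states * 2


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Forward positional args only; the reentrant checkpoint does not
        # support keyword arguments for the wrapped function.
        return module(*inputs)

    return custom_forward


hidden = torch.randn(1, 4, 8, requires_grad=True)
out = checkpoint(
    create_custom_forward(toy_layer),
    hidden,  # hidden_states
    None,    # attention_mask
    None,    # position_ids
    None,    # padding_mask
    use_reentrant=True,
)
out.sum().backward()
```

The hunks below apply the same idea to the Llama patch: rather than handing keyword arguments to the checkpointed callable, past_key_value and the new padding_mask are threaded through as plain positional slots.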

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED

@@ -99,6 +99,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
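
The extra parameter in this hunk exists because recent transformers releases call the attention forward with a padding_mask keyword; a replacement forward that does not accept it fails at the first decoder layer. A toy reproduction of that failure mode (the function names here are illustrative, not part of the patch):

```python
def old_forward(hidden_states, attention_mask=None, use_cache=False):
    return hidden_states


def new_forward(hidden_states, attention_mask=None, use_cache=False, padding_mask=None):
    # padding_mask is accepted purely for call-site compatibility; the
    # flash-attention path ignores it, hence the pylint disable above.
    return hidden_states


try:
    old_forward("h", attention_mask=None, padding_mask=None)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'padding_mask'

new_forward("h", attention_mask=None, padding_mask=None)  # accepted
```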
@@ -476,6 +477,13 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
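
This branch mirrors what recent transformers releases do in LlamaModel.forward: the raw 2-D attention_mask is only propagated to the layers as padding_mask when the batch actually contains padded positions; otherwise the layers receive None and can skip any unpadding work. A standalone sketch of that decision (the helper name is ours, for illustration):

```python
from typing import Optional

import torch


def derive_padding_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the patched llama_model_forward: only hand a padding_mask to the
    # decoder layers when at least one position is masked out (padded).
    if attention_mask is None:
        return None
    return attention_mask if 0 in attention_mask else None


full = torch.ones(2, 8, dtype=torch.long)  # no padding -> None
padded = torch.tensor(
    [[1, 1, 1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 1, 1, 1]]
)  # padding present -> the mask itself is forwarded

assert derive_padding_mask(full) is None
assert derive_padding_mask(padded) is padded
```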
@@ -510,7 +518,9 @@ def llama_model_forward(
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs)
+                    return module(
+                        *inputs,
+                    )

                 return custom_forward

@@ -519,9 +529,10 @@ def llama_model_forward(
                 hidden_states,
                 attention_mask,
                 position_ids,
-                None,
+                past_key_value,
                 output_attentions,
                 None,
+                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
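
Because everything goes through torch.utils.checkpoint.checkpoint positionally, the argument order in this hunk has to track the patched LlamaDecoderLayer.forward signature exactly. Written out with the parameter each positional slot lands in (the comments and the two leading call lines are reconstructed for illustration):

```python
layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(decoder_layer),
    hidden_states,      # hidden_states
    attention_mask,     # attention_mask
    position_ids,       # position_ids
    past_key_value,     # past_key_value (previously hard-coded to None)
    output_attentions,  # output_attentions
    None,               # use_cache: not needed under gradient checkpointing
    padding_mask,       # padding_mask, newly threaded through
    cu_seqlens,         # cu_seqlens for packed sequences
    max_seqlen,         # max_seqlen for packed sequences
)
```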
@@ -533,6 +544,7 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -579,6 +591,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -611,6 +624,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
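
For orientation, these patched callables only take effect once they are bound over the upstream transformers symbols before the model is built. Axolotl wires this up through a helper in the same module; the sketch below only illustrates the mechanism and is not that exact entry point:

```python
import transformers.models.llama.modeling_llama as llama_modeling

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    flashattn_forward,
    llama_model_forward,
)

# Illustrative only: rebind the upstream forwards so any LlamaForCausalLM
# constructed afterwards runs the flash-attention / packed-sequence paths.
# The module's own LlamaDecoderLayer subclass (last hunks above) has to be
# swapped in the same way so padding_mask/cu_seqlens reach every layer.
llama_modeling.LlamaAttention.forward = flashattn_forward
llama_modeling.LlamaModel.forward = llama_model_forward
```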
 