update for recent transformers updates (#636)
* update for recent transformers updates
* fix checkpoint forward kwargs
* just pass args into torch checkpoint
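The last two fixes come down to how arguments cross the gradient-checkpointing boundary: torch.utils.checkpoint.checkpoint re-invokes its callable with positional arguments, so llama_model_forward now passes everything the decoder layer needs (including past_key_value and the new padding_mask) positionally, and the wrapper simply splats *inputs back into the layer. Below is a minimal, self-contained sketch of that pattern, not the patched code itself; ToyLayer and its scale/offset arguments are made up stand-ins for the real decoder layer and its extra arguments.

import torch
import torch.utils.checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, scale, offset):
        # scale/offset stand in for the extra decoder-layer args
        # (past_key_value, padding_mask, cu_seqlens, ...)
        return self.linear(hidden_states) * scale + offset


def create_custom_forward(module):
    def custom_forward(*inputs):
        # just pass the positional args straight through to the module
        return module(
            *inputs,
        )

    return custom_forward


layer = ToyLayer()
hidden = torch.randn(2, 8, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer),
    hidden,
    2.0,  # scale
    1.0,  # offset
)
out.sum().backward()  # forward is recomputed with the same positional args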
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -99,6 +99,7 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -476,6 +477,13 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
+        padding_mask = None
+    else:
+        if 0 in attention_mask:
+            padding_mask = attention_mask
+        else:
+            padding_mask = None
+
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
@@ -510,7 +518,9 @@ def llama_model_forward(
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs)
+                    return module(
+                        *inputs,
+                    )
 
                 return custom_forward
 
@@ -519,9 +529,10 @@ def llama_model_forward(
                 hidden_states,
                 attention_mask,
                 position_ids,
-                None,
+                past_key_value,
                 output_attentions,
                 None,
+                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -533,6 +544,7 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -579,6 +591,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -611,6 +624,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
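For reference, the padding_mask that now gets threaded through the decoder layers is derived from the 2D attention_mask before it is expanded into the causal mask: it stays None when nothing is padded and is only populated when a 0 appears in the mask. A rough, standalone illustration of that rule; derive_padding_mask and the example tensors are hypothetical helpers, not part of the patch.

import torch


def derive_padding_mask(attention_mask):
    # mirrors the new branch in llama_model_forward: no mask, or an all-ones
    # mask, means no padding, so nothing extra is handed to the layers
    if attention_mask is None:
        return None
    if 0 in attention_mask:
        return attention_mask
    return None


full = torch.ones(2, 5, dtype=torch.long)                   # no padding anywhere
padded = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])   # right-padded batch
print(derive_padding_mask(full))    # None
print(derive_padding_mask(padded))  # the original mask, passed to every decoder layer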