visheratin committed
Commit 122865a
1 parent: 0953ba5

Update new model

Files changed (1)
  1. modeling_phi.py +54 -38
modeling_phi.py CHANGED
@@ -24,11 +24,14 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
     from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
     from flash_attn.ops.fused_dense import FusedDense
-except:
+    print("Using Flash Attention!")
+except Exception as exc:
+    print(exc)
     pad_input, unpad_input = None, None
     FlashRotaryEmbedding = None
     FlashSelfAttention, FlashCrossAttention = None, None
     FusedDense = None
+    print("Not using Flash Attention!")


 @dataclass
@@ -525,7 +528,7 @@ class MHA(nn.Module):
         softmax_scale: Optional[float] = None,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
-        checkpointing: bool = False,
+        checkpointing: bool = True,
     ) -> None:
         super().__init__()

@@ -607,7 +610,7 @@ class MHA(nn.Module):

             if self.checkpointing:
                 attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+                    self.inner_attn, qkv, None, cu_seqlens, max_seqlen, use_reentrant=False
                 )
             else:
                 attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
@@ -616,7 +619,7 @@ class MHA(nn.Module):
             return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output

         if self.checkpointing:
-            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
+            return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, None, key_padding_mask, use_reentrant=False)

         return self.inner_attn(qkv, key_padding_mask=key_padding_mask)

@@ -669,11 +672,12 @@ class MHA(nn.Module):
                     self.inner_cross_attn,
                     q,
                     kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
+                    causal,
+                    cu_seqlens_q,
+                    max_seqlen_q,
+                    cu_seqlens_k,
+                    max_seqlen_k,
+                    use_reentrant=False,
                 )
             else:
                 attn_output = self.inner_cross_attn(
@@ -697,8 +701,9 @@ class MHA(nn.Module):
                 self.inner_cross_attn,
                 q,
                 kv,
-                key_padding_mask=key_padding_mask,
-                causal=causal,
+                causal,
+                key_padding_mask,
+                use_reentrant=False,
             )

         return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
@@ -835,7 +840,7 @@ class PhiPreTrainedModel(PreTrainedModel):

     config_class = PhiConfig
     base_model_prefix = "transformer"
-    supports_gradient_checkpointing = False
+    supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]

     def __init__(self, *inputs, **kwargs) -> None:
@@ -862,20 +867,20 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
-            past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
-                seqlen_offset=0,
-                batch_size_offset=0,
-                key_value_memory_dict={},
-                lengths_per_sample=None,
-            )
-        else:
-            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-            past_key_values.seqlen_offset = input_ids.shape[1] - 1
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-            attention_mask = attention_mask[:, -1].unsqueeze(-1)
+        # if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+        #     past_key_values = InferenceParams(
+        #         max_seqlen=self.config.n_positions,
+        #         max_batch_size=input_ids.shape[0],
+        #         seqlen_offset=0,
+        #         batch_size_offset=0,
+        #         key_value_memory_dict={},
+        #         lengths_per_sample=None,
+        #     )
+        # else:
+        #     # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
+        #     past_key_values.seqlen_offset = input_ids.shape[1] - 1
+        #     input_ids = input_ids[:, -1].unsqueeze(-1)
+        #     attention_mask = attention_mask[:, -1].unsqueeze(-1)

         return {
             "input_ids": input_ids,
@@ -891,17 +896,19 @@ class PhiModel(PhiPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

     def __init__(self, config: PhiConfig) -> None:
+        config.flash_attn = True
+        config.flash_rotary = True
         super().__init__(config)

         self.embd = Embedding(config)
         self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
-        self.gradient_checkpointing = False
+        self.gradient_checkpointing = True
         self.post_init()

-    def get_input_embeddings(self):
-        return self.embd
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embd.wte

-    def set_input_embeddings(self, new_embeddings) -> None:
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
         self.embd.wte = new_embeddings

     def forward(
@@ -919,11 +926,20 @@
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         for layer in self.h:
-            hidden_states = layer(
-                hidden_states,
-                past_key_values=past_key_values,
-                attention_mask=attention_mask,
-            )
+            if self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    layer.__call__,
+                    hidden_states,
+                    past_key_values,
+                    attention_mask,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = layer(
+                    hidden_states,
+                    past_key_values=past_key_values,
+                    attention_mask=attention_mask,
+                )

         return hidden_states

@@ -947,10 +963,10 @@ class PhiForCausalLM(PhiPreTrainedModel):

         self.post_init()

-    def get_output_embeddings(self):
-        return self.lm_head
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.lm_head.linear

-    def set_output_embeddings(self, new_embeddings) -> None:
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.lm_head.linear = new_embeddings

     def forward(
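
Note: the commit switches every torch.utils.checkpoint.checkpoint call to the non-reentrant form (use_reentrant=False), which keeps only the inputs and recomputes intermediate activations during backward. Below is a minimal, self-contained sketch of that pattern; the toy module and shapes are made up for illustration and are not part of modeling_phi.py.

import torch
import torch.utils.checkpoint
from torch import nn

# Toy stand-in for a transformer block; the real code wraps ParallelBlock
# and the inner attention modules instead.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

x = torch.randn(2, 16, 64, requires_grad=True)

# Non-reentrant checkpointing: only the inputs are saved; activations inside
# `block` are recomputed during backward, trading compute for memory.
out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(x.grad.shape)  # torch.Size([2, 16, 64])

The wrapped callable's arguments are passed positionally here, mirroring how the updated calls in the diff invoke checkpoint.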
 
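Since the remote code now enables Flash Attention and gradient checkpointing by default, a quick end-to-end check is to load the checkpoint with trust_remote_code=True and run one training step so the checkpointed branches execute. A minimal sketch, assuming a hypothetical repo id "visheratin/phi-model" (substitute the real repo), a CUDA device, and flash-attn installed for the Flash Attention path:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "visheratin/phi-model"  # hypothetical id for illustration

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,  # flash-attn kernels expect fp16/bf16
    trust_remote_code=True,     # loads this modeling_phi.py
).cuda()
model.train()

batch = tokenizer(["Gradient checkpointing smoke test."], return_tensors="pt").to("cuda")
outputs = model(**batch, labels=batch["input_ids"])
outputs.loss.backward()  # exercises the use_reentrant=False checkpoint calls

Because PhiModel.__init__ now hard-codes self.gradient_checkpointing = True, inference-only users may want to set model.transformer.gradient_checkpointing = False after loading (assuming the usual transformer attribute on the causal-LM wrapper) to avoid the recomputation overhead.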