Implement MLA inference optimizations to DeepseekV2Attention

#12
by sy-chen - opened
Files changed (1) hide show
  1. modeling_deepseek.py +18 -28
modeling_deepseek.py CHANGED
@@ -822,17 +822,10 @@ class DeepseekV2Attention(nn.Module):
822
  compressed_kv, k_pe = torch.split(
823
  compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
824
  )
 
825
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
826
- kv = (
827
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
828
- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
829
- .transpose(1, 2)
830
- )
831
 
832
- k_nope, value_states = torch.split(
833
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
834
- )
835
- kv_seq_len = value_states.shape[-2]
836
  if past_key_value is not None:
837
  if self.layer_idx is None:
838
  raise ValueError(
@@ -841,27 +834,22 @@ class DeepseekV2Attention(nn.Module):
841
  "with a layer index."
842
  )
843
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
844
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
845
 
 
846
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
847
 
848
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
849
- query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
850
- query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
851
-
852
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
853
- key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
854
- key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
855
  if past_key_value is not None:
856
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
857
- key_states, value_states = past_key_value.update(
858
- key_states, value_states, self.layer_idx, cache_kwargs
859
- )
860
-
861
- attn_weights = (
862
- torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
863
- )
864
-
 
 
865
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
866
  raise ValueError(
867
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -878,11 +866,13 @@ class DeepseekV2Attention(nn.Module):
878
  # upcast attention to fp32
879
  attn_weights = nn.functional.softmax(
880
  attn_weights, dim=-1, dtype=torch.float32
881
- ).to(query_states.dtype)
882
  attn_weights = nn.functional.dropout(
883
  attn_weights, p=self.attention_dropout, training=self.training
884
  )
885
- attn_output = torch.matmul(attn_weights, value_states)
 
 
886
 
887
  if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
888
  raise ValueError(
@@ -1902,4 +1892,4 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1902
  past_key_values=transformer_outputs.past_key_values,
1903
  hidden_states=transformer_outputs.hidden_states,
1904
  attentions=transformer_outputs.attentions,
1905
- )
 
822
  compressed_kv, k_pe = torch.split(
823
  compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
824
  )
825
+ compressed_kv = self.kv_a_layernorm(compressed_kv)
826
  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
 
 
 
 
 
827
 
828
+ kv_seq_len = k_pe.shape[-2]
 
 
 
829
  if past_key_value is not None:
830
  if self.layer_idx is None:
831
  raise ValueError(
 
834
  "with a layer index."
835
  )
836
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
837
 
838
+ cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
839
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
840
 
 
 
 
 
 
 
 
841
  if past_key_value is not None:
842
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
843
+ compressed_kv = compressed_kv.unsqueeze(1)
844
+ k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
845
+ compressed_kv = compressed_kv.squeeze(1)
846
+
847
+ kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
848
+ q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
849
+ out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
850
+
851
+ q_nope = torch.matmul(q_nope, q_absorb)
852
+ attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
853
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
854
  raise ValueError(
855
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
 
866
  # upcast attention to fp32
867
  attn_weights = nn.functional.softmax(
868
  attn_weights, dim=-1, dtype=torch.float32
869
+ ).to(q_pe.dtype)
870
  attn_weights = nn.functional.dropout(
871
  attn_weights, p=self.attention_dropout, training=self.training
872
  )
873
+ attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
874
+
875
+ attn_output = torch.matmul(attn_output, out_absorb.mT)
876
 
877
  if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
878
  raise ValueError(
 
1892
  past_key_values=transformer_outputs.past_key_values,
1893
  hidden_states=transformer_outputs.hidden_states,
1894
  attentions=transformer_outputs.attentions,
1895
+ )