HyperAccel committed on
Commit
530ff1a
·
verified ·
1 Parent(s): aea997a

Upload tiny-random deepseek_v32 model

Browse files
configuration_deepseek_v32.py CHANGED
@@ -98,6 +98,12 @@ class DeepseekV32Config(PretrainedConfig):
98
  Whether to use a bias in the query, key, value and output projection layers during self-attention.
99
  attention_dropout (`float`, *optional*, defaults to 0.0):
100
  The dropout ratio for the attention probabilities.
 
 
 
 
 
 
101
 
102
  ```python
103
  >>> from transformers import DeepseekV32Model, DeepseekV32Config
@@ -152,6 +158,9 @@ class DeepseekV32Config(PretrainedConfig):
152
  rope_scaling=None,
153
  attention_bias=False,
154
  attention_dropout=0.0,
 
 
 
155
  **kwargs,
156
  ):
157
  self.vocab_size = vocab_size
@@ -192,6 +201,9 @@ class DeepseekV32Config(PretrainedConfig):
192
  self.rope_scaling = rope_scaling
193
  self.attention_bias = attention_bias
194
  self.attention_dropout = attention_dropout
 
 
 
195
 
196
  super().__init__(
197
  pad_token_id=pad_token_id,
 
98
  Whether to use a bias in the query, key, value and output projection layers during self-attention.
99
  attention_dropout (`float`, *optional*, defaults to 0.0):
100
  The dropout ratio for the attention probabilities.
101
+ index_n_heads (`int`, *optional*, defaults to 64):
102
+ Number of attention heads used in the sparse attention indexer.
103
+ index_head_dim (`int`, *optional*, defaults to 128):
104
+ Dimension of each head in the sparse attention indexer.
105
+ index_topk (`int`, *optional*, defaults to 2048):
106
+ Number of top-k key-value positions selected by the sparse attention indexer.
107
 
108
  ```python
109
  >>> from transformers import DeepseekV32Model, DeepseekV32Config
 
158
  rope_scaling=None,
159
  attention_bias=False,
160
  attention_dropout=0.0,
161
+ index_n_heads=64,
162
+ index_head_dim=128,
163
+ index_topk=2048,
164
  **kwargs,
165
  ):
166
  self.vocab_size = vocab_size
 
201
  self.rope_scaling = rope_scaling
202
  self.attention_bias = attention_bias
203
  self.attention_dropout = attention_dropout
204
+ self.index_n_heads = index_n_heads
205
+ self.index_head_dim = index_head_dim
206
+ self.index_topk = index_topk
207
 
208
  super().__init__(
209
  pad_token_id=pad_token_id,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7630389ac118c34d12846521ff5102b4ba0b97fa733ad63eb780c38aed731f0
3
- size 545819392
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa9a2cfe783c2448f7d7e3d0d2149fe388261955b0c412f1f343a15ba17369e2
3
+ size 546248736
modeling_deepseek_v32.py CHANGED
@@ -336,39 +336,39 @@ def rotate_half(x):
336
  return torch.cat((-x2, x1), dim=-1)
337
 
338
 
339
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
340
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
341
- """Applies Rotary Position Embedding to the query and key tensors.
 
 
342
 
343
  Args:
344
- q (`torch.Tensor`): The query tensor.
345
- k (`torch.Tensor`): The key tensor.
346
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
347
- sin (`torch.Tensor`): The sine part of the rotary embedding.
348
- position_ids (`torch.Tensor`):
349
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
350
- used to pass offsetted position ids when working with a KV-cache.
351
- unsqueeze_dim (`int`, *optional*, defaults to 1):
352
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
353
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
354
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
355
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
356
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
357
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
358
- Returns:
359
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
360
  """
361
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
362
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
363
-
364
- b, h, s, d = q.shape
365
- q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
 
 
 
 
 
 
 
 
 
366
 
367
- b, h, s, d = k.shape
368
- k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
369
 
370
- q_embed = (q * cos) + (rotate_half(q) * sin)
371
- k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
372
  return q_embed, k_embed
373
 
374
 
@@ -610,6 +610,128 @@ class DeepseekV32MoE(nn.Module):
610
  return final_out
611
 
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
614
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
615
  """
@@ -696,6 +818,9 @@ class DeepseekV32Attention(nn.Module):
696
  mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
697
  self.softmax_scale = self.softmax_scale * mscale * mscale
698
 
 
 
 
699
  def _init_rope(self):
700
  if self.config.rope_scaling is None:
701
  self.rotary_emb = DeepseekV32RotaryEmbedding(
@@ -767,8 +892,10 @@ class DeepseekV32Attention(nn.Module):
767
 
768
  if self.q_lora_rank is None:
769
  q = self.q_proj(hidden_states)
 
770
  else:
771
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 
772
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
773
  q_nope, q_pe = torch.split(
774
  q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
@@ -823,12 +950,27 @@ class DeepseekV32Attention(nn.Module):
823
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
824
  f" {attn_weights.size()}"
825
  )
826
- assert attention_mask is not None
827
  if attention_mask is not None:
828
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
829
  raise ValueError(
830
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
831
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832
  attn_weights = attn_weights + attention_mask
833
 
834
  # upcast attention to fp32
@@ -903,7 +1045,8 @@ class DeepseekV32FlashAttention2(DeepseekV32Attention):
903
  if self.q_lora_rank is None:
904
  q = self.q_proj(hidden_states)
905
  else:
906
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
 
907
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
908
  q_nope, q_pe = torch.split(
909
  q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 
336
  return torch.cat((-x2, x1), dim=-1)
337
 
338
 
339
def apply_rotary_emb(x, cos, sin, position_ids, unsqueeze_dim=1, interleaved=True):
    """Apply rotary positional embeddings via complex multiplication.

    Mirrors DeepSeek-V3.2-Exp/inference/model.py ``apply_rotary_emb``:
    uses ``torch.view_as_complex`` / ``torch.view_as_real`` for the rotation.

    Args:
        x: Input tensor of shape [batch, heads, seq_len, rope_dim].
        cos, sin: Cached cos/sin tables of shape [seq_len, rope_dim]; only the
            first rope_dim // 2 entries are used (the tables store each value
            duplicated across both halves).
        position_ids: Position indices of shape [batch, seq_len].
        unsqueeze_dim: Dimension to unsqueeze so cos/sin broadcast against x
            (default 1, the heads dimension).
        interleaved: If True, consecutive pairs of the last dim are
            (real, imag). If False, the first half is real and the second
            half imaginary.

    Returns:
        The rotated tensor, cast back to x's original dtype.
    """
    dtype = x.dtype
    shape = x.shape
    half = cos.shape[-1] // 2
    # torch.complex only accepts float32/float64 operands, so upcast the
    # (possibly bf16/fp16) cos/sin tables before building the rotation factors.
    cos_pos = cos[position_ids][..., :half].unsqueeze(unsqueeze_dim).float()
    sin_pos = sin[position_ids][..., :half].unsqueeze(unsqueeze_dim).float()
    freqs_cis = torch.complex(cos_pos, sin_pos)

    if not interleaved:
        # Re-pack [real_half | imag_half] into interleaved (real, imag) pairs.
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x_complex = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    y = torch.view_as_real(x_complex * freqs_cis).flatten(-2)
    if not interleaved:
        # Undo the packing: even lanes back to the first half, odd to the second.
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)
366
 
 
 
367
 
368
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Both tensors use the interleaved (real, imag) pair layout; the actual
    rotation is delegated to ``apply_rotary_emb``.
    """
    q_embed, k_embed = (
        apply_rotary_emb(t, cos, sin, position_ids, unsqueeze_dim, interleaved=True)
        for t in (q, k)
    )
    return q_embed, k_embed
373
 
374
 
 
610
  return final_out
611
 
612
 
613
def hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Fast Walsh-Hadamard transform along the last dimension.

    Pure-PyTorch butterfly decomposition: O(n log n) work instead of a dense
    n x n Hadamard matrix multiply.

    Args:
        x: Input tensor whose last dimension must be a power of two.
        scale: Multiplier applied to the transformed output (use ``n ** -0.5``
            for an orthonormal transform).

    Returns:
        The transformed tensor, same shape as ``x``.

    Raises:
        ValueError: If the last dimension is not a positive power of two
            (previously this surfaced as an obscure ``unflatten`` shape
            error part-way through the loop).
    """
    n = x.size(-1)
    if n <= 0 or n & (n - 1):
        raise ValueError(f"hadamard_transform requires a power-of-two last dimension, got {n}")
    h = 1
    while h < n:
        # One butterfly stage: merge adjacent blocks of size h into sums and
        # differences, producing blocks of size 2h.
        x = x.unflatten(-1, (-1, h * 2))
        a = x[..., :h]
        b = x[..., h:]
        x = torch.cat([a + b, a - b], dim=-1).flatten(-2)
        h *= 2
    return x * scale
624
+
625
+
626
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Spread activation magnitude evenly across the last dimension.

    Applies an orthonormal Hadamard transform (scale ``d ** -0.5`` where
    ``d`` is the size of the last dimension).
    """
    dim = x.size(-1)
    return hadamard_transform(x, scale=dim ** -0.5)
630
+
631
+
632
class DeepseekV32Indexer(nn.Module):
    """Sparse attention indexer for DeepSeek V3.2.

    Scores every key position against every query position and returns the
    indices of the top-k highest-scoring key positions, which the attention
    layer then restricts itself to (enabling sparse attention).
    """

    def __init__(self, config: DeepseekV32Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.index_n_heads
        self.head_dim = config.index_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.index_topk = config.index_topk
        self.q_lora_rank = config.q_lora_rank

        # Query projection: compressed query (q_lora_rank) -> all indexer heads.
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim, bias=False)
        # Single shared key projection from the hidden states.
        self.wk = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.k_norm = nn.LayerNorm(self.head_dim)
        # Per-head importance weights derived from the hidden states.
        self.weights_proj = nn.Linear(self.hidden_size, self.n_heads, bias=False)

        self.softmax_scale = self.head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        compressed_q: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select the top-k key positions for each query position.

        Args:
            hidden_states: [batch, seq_len, hidden_size] input activations.
            compressed_q: [batch, seq_len, q_lora_rank] output of
                q_a_layernorm(q_a_proj(hidden_states)).
            cos, sin: Rotary embedding cos/sin tables.
            position_ids: [batch, seq_len] position indices.
            attention_mask: Optional additive mask [batch, 1, seq_len, seq_len].

        Returns:
            topk_indices: [batch, seq_len, min(index_topk, seq_len)] indices
            of the key positions each query should attend to.
        """
        bsz, q_len, _ = hidden_states.size()

        # --- Indexer queries -------------------------------------------------
        q = self.wq_b(compressed_q).view(bsz, q_len, self.n_heads, self.head_dim)
        # The rope slice comes first; the remainder is position-free.
        q_pe = q[..., :self.qk_rope_head_dim]
        q_nope = q[..., self.qk_rope_head_dim:]

        # The indexer uses the non-interleaved (half-real / half-imag) RoPE
        # layout, matching the reference implementation.
        q_pe = q_pe.transpose(1, 2)  # [bsz, n_heads, q_len, rope_dim]
        q_pe = apply_rotary_emb(q_pe, cos, sin, position_ids, unsqueeze_dim=1, interleaved=False)
        q_pe = q_pe.transpose(1, 2)  # back to [bsz, q_len, n_heads, rope_dim]

        q = torch.cat([q_pe, q_nope], dim=-1)  # [bsz, q_len, n_heads, head_dim]

        # --- Indexer keys (one shared "head") --------------------------------
        k = self.k_norm(self.wk(hidden_states))  # [bsz, q_len, head_dim]
        k_pe = k[..., :self.qk_rope_head_dim]
        k_nope = k[..., self.qk_rope_head_dim:]

        k_pe = k_pe.unsqueeze(1)  # [bsz, 1, q_len, rope_dim]
        k_pe = apply_rotary_emb(k_pe, cos, sin, position_ids, unsqueeze_dim=1, interleaved=False)
        k_pe = k_pe.squeeze(1)  # [bsz, q_len, rope_dim]

        k = torch.cat([k_pe, k_nope], dim=-1)  # [bsz, q_len, head_dim]

        # Hadamard rotation spreads magnitude evenly across dimensions before
        # the dot products (from DeepSeek-V3.2-Exp/inference/model.py).
        q = rotate_activation(q)
        k = rotate_activation(k)

        # Per-head importance weights, computed on an fp32 copy of the input.
        # NOTE(review): assumes weights_proj parameters are fp32 as well so the
        # Linear's dtypes match — confirm against the loading path.
        weights = self.weights_proj(hidden_states.float()) * (self.n_heads ** -0.5)

        # score[b, i, j] = sum_h weights[b, i, h] * (q[b, i, h] . k[b, j]) * scale
        q = q.transpose(1, 2)  # [bsz, n_heads, q_len, head_dim]
        k_expanded = k.unsqueeze(1)  # [bsz, 1, q_len, head_dim]
        index_score = torch.matmul(q, k_expanded.transpose(-1, -2))  # [bsz, n_heads, q_len, q_len]
        index_score = index_score * self.softmax_scale
        # weights: [bsz, q_len, n_heads] -> [bsz, n_heads, q_len, 1] for broadcast.
        weights = weights.permute(0, 2, 1).unsqueeze(-1)
        index_score = (index_score * weights).sum(dim=1)  # [bsz, q_len, q_len]

        if attention_mask is not None:
            # Fold the additive (causal) mask in so masked positions can never
            # win the top-k selection; attention_mask is [bsz, 1, q_len, kv_len].
            index_score = index_score + attention_mask.squeeze(1)

        # Keep at most index_topk key positions per query.
        topk = min(self.index_topk, q_len)
        topk_indices = index_score.topk(topk, dim=-1)[1]  # [bsz, q_len, topk]

        return topk_indices
733
+
734
+
735
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
736
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
737
  """
 
818
  mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
819
  self.softmax_scale = self.softmax_scale * mscale * mscale
820
 
821
+ # DeepSeek V3.2 Sparse Attention Indexer
822
+ self.indexer = DeepseekV32Indexer(config)
823
+
824
  def _init_rope(self):
825
  if self.config.rope_scaling is None:
826
  self.rotary_emb = DeepseekV32RotaryEmbedding(
 
892
 
893
  if self.q_lora_rank is None:
894
  q = self.q_proj(hidden_states)
895
+ compressed_q = None
896
  else:
897
+ compressed_q = self.q_a_layernorm(self.q_a_proj(hidden_states))
898
+ q = self.q_b_proj(compressed_q)
899
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
900
  q_nope, q_pe = torch.split(
901
  q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
 
950
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
951
  f" {attn_weights.size()}"
952
  )
 
953
  if attention_mask is not None:
954
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
955
  raise ValueError(
956
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
957
  )
958
+
959
+ # DeepSeek V3.2: Apply sparse attention indexer mask (includes causal mask)
960
+ # Matching reference: causal mask is applied only once, via index_mask
961
+ if compressed_q is not None:
962
+ topk_indices = self.indexer(
963
+ hidden_states, compressed_q, cos, sin, position_ids, attention_mask
964
+ )
965
+ # Create sparse index mask: only attend to top-k positions
966
+ index_mask = torch.full(
967
+ (bsz, q_len, kv_seq_len), float("-inf"), device=hidden_states.device
968
+ )
969
+ index_mask.scatter_(-1, topk_indices, 0.0)
970
+ if attention_mask is not None:
971
+ index_mask = index_mask + attention_mask.squeeze(1)
972
+ attn_weights = attn_weights + index_mask.unsqueeze(1)
973
+ elif attention_mask is not None:
974
  attn_weights = attn_weights + attention_mask
975
 
976
  # upcast attention to fp32
 
1045
  if self.q_lora_rank is None:
1046
  q = self.q_proj(hidden_states)
1047
  else:
1048
+ compressed_q = self.q_a_layernorm(self.q_a_proj(hidden_states))
1049
+ q = self.q_b_proj(compressed_q)
1050
  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
1051
  q_nope, q_pe = torch.split(
1052
  q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1