KitsuVp commited on
Commit
d5dfbc7
·
verified ·
1 Parent(s): 87b18c1

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +334 -125
modeling_neollm.py CHANGED
@@ -787,6 +787,9 @@ class LeviathanGenerator(nn.Module):
787
  x_all: [N, M, d_seed], values in [0, 1], all heads stacked.
788
  Returns:
789
  [N, M, d_seed, n_knots] float32.
 
 
 
790
  """
791
  x32 = x_all.float()
792
  x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1]
@@ -801,6 +804,113 @@ class LeviathanGenerator(nn.Module):
801
  torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
802
  ) # [N, M, d_seed, n_knots] float32
803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
  def _khronos_all_heads(
805
  self,
806
  B_all: torch.Tensor,
@@ -927,52 +1037,39 @@ class LeviathanGenerator(nn.Module):
927
  analysis.z_tilde = z_tilde.detach()
928
  analysis.B_vals = B_vals.detach()
929
 
930
- # ── Per-head generator path (fully vectorized, 6 kernels) ────────
931
- # All 8 heads are processed simultaneously. No Python loop.
932
- # Maximum intermediate tensor [N, M, d_seed, n_knots] appears once.
933
-
934
- # Kernel 1: fused linear projection for all heads
935
- # z @ W^T [N, M*d_seed] → [N, M, d_seed]
936
- z_all = F.linear(z.to(target_dtype), self.head_proj_weight)
937
- z_all = z_all.view(N, self.num_modes, self.d_seed) # [N, M, d_seed]
938
-
939
- if analysis is not None:
940
- analysis.z_all_pre_norm = z_all.detach()
941
-
942
- # Kernel 2: per-head LayerNorm + sigmoid(x/2)
943
- # Manual LN over last dim with independent weight/bias per head.
944
- # Mathematically identical to 8 separate nn.LayerNorm(d_seed).
945
- mean = z_all.mean(dim=-1, keepdim=True)
946
- var = z_all.var(dim=-1, keepdim=True, unbiased=False)
947
- z_all = (z_all - mean) / (var + self.head_norm_eps).sqrt()
948
- # head_norm_weight/bias: [M, d_seed]broadcast over N
949
- z_all = z_all * self.head_norm_weight.unsqueeze(0) \
950
- + self.head_norm_bias.unsqueeze(0)
951
- z_all = torch.sigmoid(z_all / 2.0) # [N, M, d_seed]
952
-
953
- if analysis is not None:
954
- analysis.z_all_post_sigmoid = z_all.detach()
955
-
956
- # Kernel 3: vectorized B-spline basis for all heads
957
- # head_scale [M, d_seed] is used inside _bspline_basis_all_heads
958
- B_all = self._bspline_basis_all_heads(
959
- z_all.clamp(0.0, 1.0)
960
- ) # [N, M, d_seed, n_knots]
961
-
962
- # Kernel 4: vectorized KHRONOS tensor product for all heads
963
- modes_all = self._khronos_all_heads(B_all) # [N, M, krank]
964
-
965
- if analysis is not None:
966
- analysis.modes_all = modes_all.detach()
967
 
968
- # Kernel 5: project all heads to hidden_size and sum
969
- # einsum: token n, head m, krank k → hidden d (summed over m)
970
- # head_out_weight [M, krank, hidden_size]
971
- e = torch.einsum(
972
- "nmk,mkd->nd",
973
- modes_all.to(target_dtype),
974
- self.head_out_weight.to(target_dtype),
975
- ) # [N, hidden_size]
976
 
977
  # No W_res — confirmed absent in the authors' implementation
978
  e = e.reshape(*orig_shape, self.hidden_size)
@@ -1342,22 +1439,22 @@ class GPAS(nn.Module):
1342
 
1343
  class SeeDNorm(nn.Module):
1344
  """
1345
- Self-Rescaled Dynamic Normalization with dual dropout.
1346
  SeeDNorm(x) = [tanh(x·β^T)·α + γ] ⊙ x/RMS(x)
 
 
 
 
1347
  """
1348
 
1349
  def __init__(
1350
  self,
1351
  dim: int,
1352
  eps: float = 1e-6,
1353
- dropout_input: float = 0.01,
1354
- dropout_hidden: float = 0.01,
1355
  ):
1356
  super().__init__()
1357
- self.dim = dim
1358
- self.eps = eps
1359
- self.dropout_input = dropout_input
1360
- self.dropout_hidden = dropout_hidden
1361
 
1362
  self.gamma = nn.Parameter(torch.ones(dim))
1363
  self.beta = nn.Parameter(torch.zeros(dim))
@@ -1371,13 +1468,11 @@ class SeeDNorm(nn.Module):
1371
  x: torch.Tensor,
1372
  analysis: Optional[SeeDNormAnalysis] = None,
1373
  ) -> torch.Tensor:
1374
- x_for_dynamic = F.dropout(x, p=self.dropout_input)
1375
  rescale_factor = torch.tanh(
1376
- torch.sum(x_for_dynamic * self.beta, dim=-1, keepdim=True)
1377
  )
1378
  dynamic_scale = rescale_factor * self.alpha + self.gamma
1379
  x_normalized = self._rms_norm(x.float())
1380
- x_normalized = F.dropout(x_normalized, p=self.dropout_hidden)
1381
  output = (x_normalized * dynamic_scale.float()).type_as(x)
1382
  if analysis is not None:
1383
  analysis.rescale_factor = rescale_factor.detach()
@@ -1387,9 +1482,7 @@ class SeeDNorm(nn.Module):
1387
  return output
1388
 
1389
  def extra_repr(self) -> str:
1390
- return (f"dim={self.dim}, eps={self.eps}, "
1391
- f"dropout_input={self.dropout_input}, "
1392
- f"dropout_hidden={self.dropout_hidden}")
1393
 
1394
 
1395
  # ==================== ROTARY EMBEDDING ====================
@@ -2743,6 +2836,36 @@ class VersatileFFN(nn.Module):
2743
  - Width path load-balancing returns (output, aux_stats) for integration
2744
  with the existing NeoLLMForCausalLM aux-loss accumulation pattern.
2745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2746
  Reference:
2747
  Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in
2748
  LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531.
@@ -2920,45 +3043,97 @@ class VersatileFFN(nn.Module):
2920
  depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,L]
2921
  x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1) # [B,S,D]
2922
 
2923
- # ── Width path: Top-K routing over virtual experts ────────────────
2924
- routing_logits = self.expert_gate(x) # [B,S,N]
2925
- topk_w, topk_i = torch.topk(routing_logits, k=self.active_experts, dim=-1)
2926
- topk_w = torch.softmax(topk_w, dim=-1) # [B,S,k]
2927
-
2928
- x_flat = x.reshape(-1, D) # [N,D]
2929
- x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1]) # [N,fan]
2930
- topk_i_f = topk_i.reshape(-1, self.active_experts) # [N,k]
2931
- topk_w_f = topk_w.reshape(-1, self.active_experts) # [N,k]
2932
- N_tok = x_flat.shape[0]
2933
-
2934
- x_moe_flat = torch.zeros_like(x_flat)
2935
-
2936
- for eid in range(self.total_experts):
2937
- mask = (topk_i_f == eid)
2938
- tok_idx, k_idx = torch.where(mask)
2939
- if tok_idx.numel() == 0:
2940
- continue
2941
- w_e = topk_w_f[tok_idx, k_idx].unsqueeze(-1)
2942
- out_e = self._expert_forward(
2943
- x_fan_flat[tok_idx], x_flat[tok_idx], self.expert_idx[eid]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2944
  )
2945
- x_moe_flat.index_add_(
2946
- 0, tok_idx, (out_e * w_e).to(x_moe_flat.dtype)
 
 
 
 
 
2947
  )
2948
 
2949
  x_moe = x_moe_flat.reshape(B, S, D)
2950
 
2951
- # Load-balancing aux stats (same pattern as JTok-M)
 
 
2952
  r_probs_flat = torch.softmax(
2953
- routing_logits.reshape(-1, self.total_experts), dim=-1
2954
- ) # [N_tok, N_experts]
2955
- p_sum = r_probs_flat.sum(dim=0) # [N_experts]
2956
- f_counts = torch.zeros(
2957
- self.total_experts, device=x.device, dtype=x.dtype
2958
- )
2959
- for eid in range(self.total_experts):
2960
- f_counts[eid] = (topk_i_f == eid).float().sum()
2961
- f_sum = f_counts / (N_tok * self.active_experts) # [N_experts]
2962
  aux_stats = (p_sum, f_sum, N_tok)
2963
 
2964
  # ── Difficulty-aware fusion (Eq. 12–13) ──────────────────────────
@@ -2976,16 +3151,27 @@ class VersatileFFN(nn.Module):
2976
  # ═════════════════════ INFERENCE ══════════════════════════════════════
2977
  else:
2978
  loop_choice = depth_logits.argmax(dim=-1) # [B, S]
2979
- max_loop = int(loop_choice.max().item())
2980
 
2981
- # Depth path: early exit only compute needed iterations
 
 
 
 
 
 
 
 
 
 
 
 
2982
  depth_outputs = []
2983
  current_x = x
2984
- for _ in range(max_loop + 1):
2985
  current_x = self._full_forward_step(current_x)
2986
  depth_outputs.append(current_x)
2987
 
2988
- depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,run]
2989
  gather_idx = (
2990
  loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1)
2991
  )
@@ -2995,40 +3181,63 @@ class VersatileFFN(nn.Module):
2995
  expected_L = (loop_choice + 1).float() # [B, S]
2996
  moe_weight = (self.max_depth - expected_L) / self.max_depth # [B, S]
2997
 
2998
- # Width path: conditional on λ > 0 (Conditional Parallelism)
2999
- active_mask = (moe_weight > 1e-6) # [B, S]
3000
- x_moe = torch.zeros_like(x)
3001
  aux_stats = None
3002
  depth_probs = None
3003
 
3004
- if active_mask.any():
3005
- x_flat_all = x.reshape(-1, D)
3006
- x_fan_flat_all = x_fan.reshape(-1, x_fan.shape[-1])
3007
- active_flat = active_mask.reshape(-1)
3008
- x_active = x_flat_all[active_flat]
3009
- x_fan_active = x_fan_flat_all[active_flat]
3010
-
3011
- r_log = self.expert_gate(x_active) # [Na, N]
3012
- tw, ti = torch.topk(r_log, k=self.active_experts, dim=-1)
3013
- tw = torch.softmax(tw, dim=-1)
3014
-
3015
- x_moe_active = torch.zeros_like(x_active)
3016
- for eid in range(self.total_experts):
3017
- mask_e = (ti == eid)
3018
- tok_idx, k_idx = torch.where(mask_e)
3019
- if tok_idx.numel() == 0:
3020
- continue
3021
- w_e = tw[tok_idx, k_idx].unsqueeze(-1)
3022
- out_e = self._expert_forward(
3023
- x_fan_active[tok_idx], x_active[tok_idx], self.expert_idx[eid]
3024
- )
3025
- x_moe_active.index_add_(
3026
- 0, tok_idx, (out_e * w_e).to(x_moe_active.dtype)
3027
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3028
 
3029
- x_moe_flat = x_moe.reshape(-1, D)
3030
- x_moe_flat[active_flat] = x_moe_active
3031
- x_moe = x_moe_flat.reshape(B, S, D)
3032
 
3033
  output = (
3034
  x_depth * (1.0 - moe_weight.unsqueeze(-1))
 
787
  x_all: [N, M, d_seed], values in [0, 1], all heads stacked.
788
  Returns:
789
  [N, M, d_seed, n_knots] float32.
790
+
791
+ NOTE: Este método se mantiene para compatibilidad con JTok-M y análisis.
792
+ El forward del generator ya NO lo usa — usa _compute_head en su lugar.
793
  """
794
  x32 = x_all.float()
795
  x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1]
 
804
  torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
805
  ) # [N, M, d_seed, n_knots] float32
806
 
807
+ def _compute_head(
808
+ self,
809
+ z: torch.Tensor,
810
+ m: int,
811
+ ) -> torch.Tensor:
812
+ """
813
+ Forward completo para el cabezal m del generator.
814
+
815
+ Reemplaza la materialización conjunta [N, M, d_seed, n_knots] del path
816
+ vectorizado. Cada llamada materializa solo [N, d_seed, n_knots] (1 cabezal),
817
+ reduciendo el pico de memoria de O(M·d_seed·n_knots) a O(d_seed·n_knots)
818
+ por cabezal.
819
+
820
+ Pipeline:
821
+ z [N, d_seed]
822
+ → Linear(head_proj_weight[m*d_seed:(m+1)*d_seed]) → [N, d_seed]
823
+ → ManualLayerNorm(weight[m], bias[m]) → [N, d_seed]
824
+ → sigmoid(x/2) → [N, d_seed] (coordenada en [0,1]^d_seed)
825
+ → B-spline KHRONOS con scale=head_scale[m] → [N, d_seed, n_knots]
826
+ → einsum con head_spline[m] → per_dim [N, d_seed, krank]
827
+ → sign-parity product (log-sum-exp) → modes [N, krank]
828
+ → Linear(head_out_weight[m]) → [N, hidden_size]
829
+
830
+ Por qué loop Python sobre M cabezales en lugar de vmap:
831
+ torch.vmap sobre cabezales con parámetros distintos requiere
832
+ functional_call y stack_module_state, lo que complica el acceso
833
+ a buffers (knot_grid, head_norm_eps) desde dentro del transform.
834
+ Un loop Python con M=8 fijo es unrolleado por TorchDynamo en una
835
+ secuencia estática de ops — exactamente como lo hace XLA/Flax en
836
+ la implementación original de Reza. El compilador ve 8 grafos
837
+ idénticos en estructura pero con parámetros distintos, y puede
838
+ fusionarlos u optimizarlos de forma independiente. Con chunk_size=1
839
+ en vmap el comportamiento sería análogo pero con mayor overhead de
840
+ instrumentación.
841
+
842
+ Args:
843
+ z: [N, d_seed] — codebook seed compartido (float del dtype del modelo).
844
+ m: índice del cabezal (0 ≤ m < num_modes), Python int estático.
845
+ Returns:
846
+ [N, hidden_size] — contribución de este cabezal al embedding final.
847
+ """
848
+ d = self.d_seed
849
+ nk = self.num_knots
850
+ kr = self.krank
851
+
852
+ # ── Proyección lineal para el cabezal m ──────────────────────────
853
+ # head_proj_weight [M*d_seed, d_seed] — los pesos del cabezal m
854
+ # son las filas [m*d_seed : (m+1)*d_seed].
855
+ proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
856
+ zh = F.linear(z.float(), proj_w) # [N, d_seed]
857
+
858
+ # ── LayerNorm manual por cabezal ──────────────────────────────────
859
+ # Equivalente a nn.LayerNorm(d_seed) con parámetros independientes
860
+ # head_norm_weight[m] y head_norm_bias[m].
861
+ mean = zh.mean(dim=-1, keepdim=True)
862
+ var = zh.var(dim=-1, keepdim=True, unbiased=False)
863
+ zh = (zh - mean) / (var + self.head_norm_eps).sqrt()
864
+ zh = zh * self.head_norm_weight[m] + self.head_norm_bias[m]
865
+
866
+ # ── Sigmoid(x/2) → coordenada latente en [0,1]^d_seed ────────────
867
+ zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
868
+
869
+ # ── B-spline KHRONOS para este cabezal ────────────────────────────
870
+ # head_scale[m]: [d_seed] — escala por dimensión para este cabezal.
871
+ # Materializa [N, d_seed, n_knots] en lugar de [N, M, d_seed, n_knots].
872
+ sc = self.head_scale[m].float().view(1, -1, 1) # [1, d_seed, 1]
873
+ x_e = zh.unsqueeze(-1) # [N, d_seed, 1]
874
+ grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
875
+ dist = (x_e - grid).abs() * sc # [N, d_seed, n_knots]
876
+ B_m = torch.where(
877
+ dist < 0.5,
878
+ 0.75 - dist ** 2,
879
+ torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
880
+ ) # [N, d_seed, n_knots]
881
+
882
+ # ── KHRONOS tensor product para este cabezal ──────────────────────
883
+ # head_spline[m]: [d_seed, n_knots, krank]
884
+ # per_dim[n, d, k] = Σ_g B_m[n, d, g] * head_spline[m, d, g, k]
885
+ # Shape: [N, d_seed, krank] — pico máximo en este cabezal.
886
+ per_dim = torch.einsum(
887
+ "ndg,dgk->ndk",
888
+ B_m,
889
+ self.head_spline[m].float(),
890
+ ) # [N, d_seed, krank]
891
+
892
+ # Sign-parity log-product (KHRONOS): evita underflow multiplicando
893
+ # en log-space y recuperando el signo por paridad de negativos.
894
+ per_dim_abs = per_dim.abs() + 1e-9
895
+ log_mag = torch.log(per_dim_abs).sum(dim=1) # [N, krank]
896
+ num_neg = (per_dim < 0).long().sum(dim=1) # [N, krank]
897
+ prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, krank]
898
+ modes_m = prod_sign * torch.exp(log_mag) # [N, krank]
899
+
900
+ # ── Proyección de salida del cabezal ──────────────────────────────
901
+ # head_out_weight[m]: [krank, hidden_size]
902
+ # NOTA: NO usar F.linear aquí. F.linear(A, W) computa A @ W.T,
903
+ # esperando W con shape [out, in] = [hidden, krank]. Pero
904
+ # head_out_weight está almacenado como [krank, hidden] (igual que
905
+ # el einsum original "nmk,mkd->nd" que contrae sobre k sin transponer).
906
+ # La multiplicación correcta es modes_m @ W directamente:
907
+ # [N, krank] @ [krank, hidden] → [N, hidden]
908
+ out_m = (
909
+ modes_m.to(self.head_out_weight.dtype)
910
+ @ self.head_out_weight[m]
911
+ ) # [N, hidden_size]
912
+ return out_m
913
+
914
  def _khronos_all_heads(
915
  self,
916
  B_all: torch.Tensor,
 
1037
  analysis.z_tilde = z_tilde.detach()
1038
  analysis.B_vals = B_vals.detach()
1039
 
1040
+ # ── Per-head generator path (secuencial, un cabezal a la vez) ──────
1041
+ # ORIGINAL PROBLEM: el path vectorizado anterior procesaba los M
1042
+ # cabezales en paralelo con kernels fusionados:
1043
+ #
1044
+ # _bspline_basis_all_heads [N, M, d_seed, n_knots] ← TENSOR GIGANTE
1045
+ # _khronos_all_heads per_dim [N, M, d_seed, krank] ← AÚN MAYOR
1046
+ #
1047
+ # Con N=B*S=32768, M=8, d_seed=128, n_knots=32, krank=16:
1048
+ # [N,M,d_seed,n_knots] = 32768 × 8 × 128 × 32 × 4 bytes ≈ 512 MB
1049
+ # [N,M,d_seed,krank] = 32768 × 8 × 128 × 16 × 4 bytes ≈ 256 MB
1050
+ # Estos tensores viven simultáneamente en el pool de CUDAGraphs,
1051
+ # causando OOM en el backward cuando se suman las activaciones guardadas
1052
+ # de las 12 capas del decoder.
1053
+ #
1054
+ # SOLUCIÓN (equivalente a la impl. JAX de Reza):
1055
+ # Loop Python sobre M=8 cabezales (count fijo → TorchDynamo unrollea
1056
+ # en 8 secuencias de ops estáticas sin graph breaks).
1057
+ # Cada cabezal materializa como máximo [N, d_seed, krank] ≈ 32 MB.
1058
 + # La suma se acumula in-place; el tensor del cabezal anterior puede
1059
+ # ser liberado por el allocator antes de procesar el siguiente.
1060
+ #
1061
+ # Por qué NO vmap(chunk_size=1):
1062
+ # vmap requiere que la función sea "pura" (sin acceso a self.*).
1063
+ # head_norm_eps, knot_grid y los parámetros indexados [m] se pasan
1064
+ # implícitamente a través del closure. Con vmap habría que
1065
+ # stack_module_state + functional_call, lo que añade overhead de
1066
+ # instrumentación sin beneficio real ya que el loop estático es
1067
+ # igualmente trazable por el compilador y produce el mismo grafo.
 
 
 
 
 
 
 
 
 
1068
 
1069
+ target_dtype = self.codebooks.dtype
1070
+ e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
1071
+ for m in range(self.num_modes):
1072
+ e = e + self._compute_head(z, m)
 
 
 
 
1073
 
1074
  # No W_res — confirmed absent in the authors' implementation
1075
  e = e.reshape(*orig_shape, self.hidden_size)
 
1439
 
1440
  class SeeDNorm(nn.Module):
1441
  """
1442
+ Self-Rescaled Dynamic Normalization.
1443
  SeeDNorm(x) = [tanh(x·β^T)·α + γ] ⊙ x/RMS(x)
1444
+
1445
+ rescale_factor = tanh(x · β) ∈ (-1, 1) escalar por token
1446
+ dynamic_scale = rescale_factor · α + γ ∈ ℝ^dim
1447
+ output = dynamic_scale ⊙ RMSNorm(x)
1448
  """
1449
 
1450
  def __init__(
1451
  self,
1452
  dim: int,
1453
  eps: float = 1e-6,
 
 
1454
  ):
1455
  super().__init__()
1456
+ self.dim = dim
1457
+ self.eps = eps
 
 
1458
 
1459
  self.gamma = nn.Parameter(torch.ones(dim))
1460
  self.beta = nn.Parameter(torch.zeros(dim))
 
1468
  x: torch.Tensor,
1469
  analysis: Optional[SeeDNormAnalysis] = None,
1470
  ) -> torch.Tensor:
 
1471
  rescale_factor = torch.tanh(
1472
+ torch.sum(x * self.beta, dim=-1, keepdim=True)
1473
  )
1474
  dynamic_scale = rescale_factor * self.alpha + self.gamma
1475
  x_normalized = self._rms_norm(x.float())
 
1476
  output = (x_normalized * dynamic_scale.float()).type_as(x)
1477
  if analysis is not None:
1478
  analysis.rescale_factor = rescale_factor.detach()
 
1482
  return output
1483
 
1484
  def extra_repr(self) -> str:
1485
+ return f"dim={self.dim}, eps={self.eps}"
 
 
1486
 
1487
 
1488
  # ==================== ROTARY EMBEDDING ====================
 
2836
  - Width path load-balancing returns (output, aux_stats) for integration
2837
  with the existing NeoLLMForCausalLM aux-loss accumulation pattern.
2838
 
2839
+ Width dispatch (CUDAGraph-compatible sparse routing):
2840
+ El dispatch original del paper (torch.where + index_add_) es sparse y
2841
+ fiel al paper pero produce shapes dependientes de datos → incompatible
2842
+ con CUDAGraphs. La implementación usa argsort como dispatcher estático:
2843
+
2844
+ flat_expert [N·K] → argsort → perm [N·K] (shape siempre igual)
2845
+ sorted_tok [N·K] = flat_tok[perm] (índices de token originales)
2846
+ grouped_tok [E, C] = sorted_tok.view(E, C) (C = N·K // E, constante)
2847
+
2848
+ Propiedades clave:
2849
+ · argsort: output shape = input shape, siempre [N·K]. CUDAGraph ✓
2850
+ · C = N_tok·K // E es un entero Python conocido en compile-time.
2851
+ Con el aux loss manteniendo balance, cada experto recibe ≈ C slots.
2852
+ · scatter_add_ con index [C, D] de shape estático: CUDAGraph ✓
2853
+ (los VALORES del index cambian por batch, no el SHAPE).
2854
+ · FLOPs idénticos al original: cada experto procesa [C, D] = [N·K/E, D]
2855
+ tokens, no todos los N tokens. Con K=2, E=4: C = N/2 por experto.
2856
+
2857
+ Conditional Parallelism (inferencia, Algorithm 2):
2858
+ · Los tokens con λ=0 (argmax → max_depth) igualmente participan en el
2859
+ grouped buffer y su expert forward se computa (shapes estáticos).
2860
+ · Su contribución es cancelada por λ=0 en la fusión:
2861
+ output = x_depth·(1−λ) + x_moe·λ → x_depth si ��=0
2862
+ · Esto pierde el saving de FLOPs de los λ=0 tokens, pero la correctitud
2863
+ matemática es exacta. Tradeoff aceptable vs CUDAGraph-incompatibilidad.
2864
+
2865
+ Discrete Early-Exit (inferencia, Algorithm 2):
2866
+ · Sustituido por always-max_depth + torch.gather con loop_choice.
2867
+ Para max_depth=2 el overhead es ≤ 1 iteración extra por token.
2868
+
2869
  Reference:
2870
  Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in
2871
  LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531.
 
3043
  depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,L]
3044
  x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1) # [B,S,D]
3045
 
3046
+ # ── Width path: argsort-based sparse dispatch (Eq. 7–8) ──────────
3047
+ # Matemática (paper §3.2):
3048
+ # Y_width = Σ_{k∈TopK} g_k · Y_k,
3049
+ # Y_k = H + W_out^(k) φ(W_proj^(k) LayerNorm(H)) (Eq. 8)
3050
+ # Como Σ_{k∈TopK} g_k = 1 (softmax normalizado sobre TopK):
3051
+ # Y_width = H + Σ_{k∈TopK} g_k · delta_k
3052
+ #
3053
+ # Implementación sparse con shapes estáticos:
3054
+ #
3055
+ # 1. flat_expert [N_tok·K]: índices de experto por token-slot.
3056
+ # argsort → perm [N_tok·K] con shape siempre igual. CUDAGraph ✓
3057
+ #
3058
+ # 2. sorted_tok [N_tok·K] = flat_tok[perm]: tokens ordenados por
3059
+ # experto. Todos los tokens del experto e quedan contiguos.
3060
+ #
3061
+ # 3. view(E, C) con C = N_tok·K // E constante Python → shape
3062
+ # estático [E, C, D] para gather y forward.
3063
+ #
3064
+ # 4. _expert_forward sobre [C, D] por experto — mismos FLOPs que
3065
+ # el original con torch.where: solo C tokens por experto,
3066
+ # no los N_tok completos. Con K=2, E=4: C = N_tok/2.
3067
+ #
3068
+ # 5. scatter_add_: index de shape [C, D] siempre estático.
3069
+ # Los VALORES varían por batch, el SHAPE no. CUDAGraph ✓
3070
+ # Acumula Σ_{k} g_k · Y_k para cada token n mediante
3071
+ # sum sobre los K slots que apuntan a n.
3072
+ K = self.active_experts
3073
+ E = self.total_experts
3074
+ N_tok = B * S
3075
+ C = (N_tok * K) // E # tokens por experto — constante compile-time
3076
+
3077
+ routing_logits = self.expert_gate(x) # [B, S, E]
3078
+ topk_w, topk_i = torch.topk(routing_logits, k=K, dim=-1)
3079
+ topk_w = torch.softmax(topk_w, dim=-1) # [B, S, K]
3080
+
3081
+ x_flat = x.reshape(-1, D) # [N_tok, D]
3082
+ x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1]) # [N_tok, fan_dim]
3083
+
3084
+ # Aplanar: cada token aparece K veces, una por experto seleccionado
3085
+ flat_expert = topk_i.reshape(-1) # [N_tok·K] long
3086
+ flat_tok = (
3087
+ torch.arange(N_tok, device=x.device, dtype=torch.long)
3088
+ .unsqueeze(1).expand(N_tok, K).reshape(-1)
3089
+ ) # [N_tok·K] long
3090
+ flat_w = topk_w.reshape(-1) # [N_tok·K]
3091
+
3092
+ # Ordenar por expert ID: todos los tokens del mismo experto juntos
3093
+ perm = torch.argsort(flat_expert, stable=True) # [N_tok·K] long
3094
+ sorted_tok = flat_tok[perm] # [N_tok·K] long
3095
+ sorted_w = flat_w[perm] # [N_tok·K]
3096
+
3097
+ # Agrupar por experto [E, C] — C conocido en compile-time
3098
+ grouped_tok = sorted_tok.view(E, C) # [E, C] long
3099
+ grouped_w = sorted_w.view(E, C) # [E, C]
3100
+
3101
+ # Gather features del token original para cada slot de experto
3102
+ flat_idx = grouped_tok.reshape(-1) # [E·C] long
3103
+ fan_dim = x_fan_flat.shape[-1]
3104
+ x_grouped = x_flat[flat_idx].view(E, C, D) # [E, C, D]
3105
+ xf_grouped = x_fan_flat[flat_idx].view(E, C, fan_dim) # [E, C, fan_dim]
3106
+
3107
+ # Expert forward + scatter_add_ de vuelta a [N_tok, D]
3108
+ # Loop desenrollado por dynamo (E constante Python) — sin graph breaks
3109
+ x_moe_flat = torch.zeros(N_tok, D, device=x.device, dtype=x.dtype)
3110
+ for eid in range(E):
3111
+ # out_e [C, D] = x_grouped[eid] + delta_e (residual incluido, Eq. 8)
3112
+ out_e = self._expert_forward(
3113
+ xf_grouped[eid], x_grouped[eid], self.expert_idx[eid]
3114
  )
3115
+ w_e = grouped_w[eid].unsqueeze(-1) # [C, 1]
3116
+ tok_idx_e = grouped_tok[eid].unsqueeze(1).expand(C, D) # [C, D] long
3117
+ # Acumula g_k · Y_k en la posición original del token
3118
+ # Cuando eid recorre los K experts de un token n:
3119
+ # x_moe_flat[n] = Σ_k g_k · Y_k = H_n + Σ_k g_k · delta_k
3120
+ x_moe_flat.scatter_add_(
3121
+ 0, tok_idx_e, (out_e * w_e).to(x_moe_flat.dtype)
3122
  )
3123
 
3124
  x_moe = x_moe_flat.reshape(B, S, D)
3125
 
3126
+ # Load-balancing aux stats (Eq. load-balancing loss)
3127
+ # p_sum: probabilidad media por experto (sobre routing_logits completo)
3128
+ # f_sum: fracción real de tokens asignados a cada experto
3129
  r_probs_flat = torch.softmax(
3130
+ routing_logits.reshape(-1, E), dim=-1
3131
+ ) # [N_tok, E]
3132
+ p_sum = r_probs_flat.sum(dim=0) # [E]
3133
+ f_sum = (
3134
+ F.one_hot(flat_expert.long(), E).float().sum(dim=0)
3135
+ / float(N_tok * K)
3136
+ ) # [E]
 
 
3137
  aux_stats = (p_sum, f_sum, N_tok)
3138
 
3139
  # ── Difficulty-aware fusion (Eq. 12–13) ──────────────────────────
 
3151
  # ═════════════════════ INFERENCE ══════════════════════════════════════
3152
  else:
3153
  loop_choice = depth_logits.argmax(dim=-1) # [B, S]
 
3154
 
3155
+ # ── Depth path: siempre max_depth iteraciones (shape estático)
3156
+ # ORIGINAL PROBLEM: el early-exit original usaba
3157
+ # max_loop = int(loop_choice.max().item())
3158
+ # que produce una sincronización CPU-GPU (equivalente a .item())
3159
+ # y hace que el número de iteraciones del loop dependa de datos —
3160
+ # ambas condiciones prohíben la captura de CUDAGraphs.
3161
+ #
3162
+ # SOLUCIÓN: siempre se ejecutan exactamente self.max_depth
3163
+ # iteraciones. depth_stack [B,S,D,max_depth] tiene shape estático.
3164
+ # El gather sobre loop_choice selecciona la salida correcta por
3165
+ # token sin necesidad de conocer cuántas iteraciones se ejecutaron.
3166
+ # La pérdida de FLOPs por iteraciones "extra" es mínima porque
3167
+ # max_depth es pequeño (default 2) y _full_forward_step es ligero.
3168
  depth_outputs = []
3169
  current_x = x
3170
+ for _ in range(self.max_depth):
3171
  current_x = self._full_forward_step(current_x)
3172
  depth_outputs.append(current_x)
3173
 
3174
+ depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,max_depth]
3175
  gather_idx = (
3176
  loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1)
3177
  )
 
3181
  expected_L = (loop_choice + 1).float() # [B, S]
3182
  moe_weight = (self.max_depth - expected_L) / self.max_depth # [B, S]
3183
 
 
 
 
3184
  aux_stats = None
3185
  depth_probs = None
3186
 
3187
+ # ── Width path: argsort-based sparse dispatch (mismo mecanismo
3188
+ # que entrenamiento, Eq. 7–8 + Conditional Parallelism §A) ───
3189
+ #
3190
+ # Conditional Parallelism (Algorithm 2 del paper):
3191
+ # Si λ=0 para un token → Y = Y_depth, el width path se omite.
3192
+ # Con shapes estáticos no podemos excluir dinámicamente esos tokens
3193
+ # del buffer. En su lugar, los λ=0 tokens participan en el grouped
3194
+ # buffer y su expert forward corre, pero la fusión
3195
+ # output = x_depth·(1−λ) + x_moe·λ
3196
+ # garantiza output = x_depth cuando λ=0, sin ninguna rama condicional.
3197
+ # Los FLOPs del width path para esos tokens son el único overhead.
3198
+ N_tok_inf = B * S
3199
+ K_inf = self.active_experts
3200
+ E_inf = self.total_experts
3201
+ C_inf = (N_tok_inf * K_inf) // E_inf
3202
+
3203
+ x_flat = x.reshape(-1, D)
3204
+ x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1])
3205
+
3206
+ routing_logits = self.expert_gate(x_flat) # [N_tok, E]
3207
+ tw, ti = torch.topk(routing_logits, k=K_inf, dim=-1)
3208
+ tw = torch.softmax(tw, dim=-1) # [N_tok, K]
3209
+
3210
+ flat_expert_i = ti.reshape(-1) # [N_tok·K] long
3211
+ flat_tok_i = (
3212
+ torch.arange(N_tok_inf, device=x.device, dtype=torch.long)
3213
+ .unsqueeze(1).expand(N_tok_inf, K_inf).reshape(-1)
3214
+ ) # [N_tok·K] long
3215
+ flat_w_i = tw.reshape(-1) # [N_tok·K]
3216
+
3217
+ perm_i = torch.argsort(flat_expert_i, stable=True) # [N_tok·K]
3218
+ sorted_tok_i = flat_tok_i[perm_i] # [N_tok·K]
3219
+ sorted_w_i = flat_w_i[perm_i] # [N_tok·K]
3220
+
3221
+ grouped_tok_i = sorted_tok_i.view(E_inf, C_inf) # [E, C]
3222
+ grouped_w_i = sorted_w_i.view(E_inf, C_inf) # [E, C]
3223
+
3224
+ flat_idx_i = grouped_tok_i.reshape(-1) # [E·C]
3225
+ fan_dim_i = x_fan_flat.shape[-1]
3226
+ x_grouped_i = x_flat[flat_idx_i].view(E_inf, C_inf, D) # [E, C, D]
3227
+ xf_grouped_i = x_fan_flat[flat_idx_i].view(E_inf, C_inf, fan_dim_i) # [E, C, fan_dim]
3228
+
3229
+ x_moe_flat_i = torch.zeros(N_tok_inf, D, device=x.device, dtype=x.dtype)
3230
+ for eid in range(E_inf):
3231
+ out_e_i = self._expert_forward(
3232
+ xf_grouped_i[eid], x_grouped_i[eid], self.expert_idx[eid]
3233
+ )
3234
+ w_e_i = grouped_w_i[eid].unsqueeze(-1) # [C, 1]
3235
+ tok_idx_e_i = grouped_tok_i[eid].unsqueeze(1).expand(C_inf, D)
3236
+ x_moe_flat_i.scatter_add_(
3237
+ 0, tok_idx_e_i, (out_e_i * w_e_i).to(x_moe_flat_i.dtype)
3238
+ )
3239
 
3240
+ x_moe = x_moe_flat_i.reshape(B, S, D)
 
 
3241
 
3242
  output = (
3243
  x_depth * (1.0 - moe_weight.unsqueeze(-1))