Update modeling_neollm.py
Browse files- modeling_neollm.py +334 -125
modeling_neollm.py
CHANGED
|
@@ -787,6 +787,9 @@ class LeviathanGenerator(nn.Module):
|
|
| 787 |
x_all: [N, M, d_seed], values in [0, 1], all heads stacked.
|
| 788 |
Returns:
|
| 789 |
[N, M, d_seed, n_knots] float32.
|
|
|
|
|
|
|
|
|
|
| 790 |
"""
|
| 791 |
x32 = x_all.float()
|
| 792 |
x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1]
|
|
@@ -801,6 +804,113 @@ class LeviathanGenerator(nn.Module):
|
|
| 801 |
torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
|
| 802 |
) # [N, M, d_seed, n_knots] float32
|
| 803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
def _khronos_all_heads(
|
| 805 |
self,
|
| 806 |
B_all: torch.Tensor,
|
|
@@ -927,52 +1037,39 @@ class LeviathanGenerator(nn.Module):
|
|
| 927 |
analysis.z_tilde = z_tilde.detach()
|
| 928 |
analysis.B_vals = B_vals.detach()
|
| 929 |
|
| 930 |
-
# ── Per-head generator path (
|
| 931 |
-
#
|
| 932 |
-
#
|
| 933 |
-
|
| 934 |
-
#
|
| 935 |
-
#
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
#
|
| 943 |
-
#
|
| 944 |
-
#
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
#
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
#
|
| 957 |
-
#
|
| 958 |
-
B_all = self._bspline_basis_all_heads(
|
| 959 |
-
z_all.clamp(0.0, 1.0)
|
| 960 |
-
) # [N, M, d_seed, n_knots]
|
| 961 |
-
|
| 962 |
-
# Kernel 4: vectorized KHRONOS tensor product for all heads
|
| 963 |
-
modes_all = self._khronos_all_heads(B_all) # [N, M, krank]
|
| 964 |
-
|
| 965 |
-
if analysis is not None:
|
| 966 |
-
analysis.modes_all = modes_all.detach()
|
| 967 |
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
"nmk,mkd->nd",
|
| 973 |
-
modes_all.to(target_dtype),
|
| 974 |
-
self.head_out_weight.to(target_dtype),
|
| 975 |
-
) # [N, hidden_size]
|
| 976 |
|
| 977 |
# No W_res — confirmed absent in the authors' implementation
|
| 978 |
e = e.reshape(*orig_shape, self.hidden_size)
|
|
@@ -1342,22 +1439,22 @@ class GPAS(nn.Module):
|
|
| 1342 |
|
| 1343 |
class SeeDNorm(nn.Module):
|
| 1344 |
"""
|
| 1345 |
-
Self-Rescaled Dynamic Normalization
|
| 1346 |
SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1347 |
"""
|
| 1348 |
|
| 1349 |
def __init__(
|
| 1350 |
self,
|
| 1351 |
dim: int,
|
| 1352 |
eps: float = 1e-6,
|
| 1353 |
-
dropout_input: float = 0.01,
|
| 1354 |
-
dropout_hidden: float = 0.01,
|
| 1355 |
):
|
| 1356 |
super().__init__()
|
| 1357 |
-
self.dim
|
| 1358 |
-
self.eps
|
| 1359 |
-
self.dropout_input = dropout_input
|
| 1360 |
-
self.dropout_hidden = dropout_hidden
|
| 1361 |
|
| 1362 |
self.gamma = nn.Parameter(torch.ones(dim))
|
| 1363 |
self.beta = nn.Parameter(torch.zeros(dim))
|
|
@@ -1371,13 +1468,11 @@ class SeeDNorm(nn.Module):
|
|
| 1371 |
x: torch.Tensor,
|
| 1372 |
analysis: Optional[SeeDNormAnalysis] = None,
|
| 1373 |
) -> torch.Tensor:
|
| 1374 |
-
x_for_dynamic = F.dropout(x, p=self.dropout_input)
|
| 1375 |
rescale_factor = torch.tanh(
|
| 1376 |
-
torch.sum(
|
| 1377 |
)
|
| 1378 |
dynamic_scale = rescale_factor * self.alpha + self.gamma
|
| 1379 |
x_normalized = self._rms_norm(x.float())
|
| 1380 |
-
x_normalized = F.dropout(x_normalized, p=self.dropout_hidden)
|
| 1381 |
output = (x_normalized * dynamic_scale.float()).type_as(x)
|
| 1382 |
if analysis is not None:
|
| 1383 |
analysis.rescale_factor = rescale_factor.detach()
|
|
@@ -1387,9 +1482,7 @@ class SeeDNorm(nn.Module):
|
|
| 1387 |
return output
|
| 1388 |
|
| 1389 |
def extra_repr(self) -> str:
|
| 1390 |
-
return
|
| 1391 |
-
f"dropout_input={self.dropout_input}, "
|
| 1392 |
-
f"dropout_hidden={self.dropout_hidden}")
|
| 1393 |
|
| 1394 |
|
| 1395 |
# ==================== ROTARY EMBEDDING ====================
|
|
@@ -2743,6 +2836,36 @@ class VersatileFFN(nn.Module):
|
|
| 2743 |
- Width path load-balancing returns (output, aux_stats) for integration
|
| 2744 |
with the existing NeoLLMForCausalLM aux-loss accumulation pattern.
|
| 2745 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2746 |
Reference:
|
| 2747 |
Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in
|
| 2748 |
LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531.
|
|
@@ -2920,45 +3043,97 @@ class VersatileFFN(nn.Module):
|
|
| 2920 |
depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,L]
|
| 2921 |
x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1) # [B,S,D]
|
| 2922 |
|
| 2923 |
-
# ── Width path:
|
| 2924 |
-
|
| 2925 |
-
|
| 2926 |
-
|
| 2927 |
-
|
| 2928 |
-
|
| 2929 |
-
|
| 2930 |
-
|
| 2931 |
-
|
| 2932 |
-
|
| 2933 |
-
|
| 2934 |
-
|
| 2935 |
-
|
| 2936 |
-
|
| 2937 |
-
|
| 2938 |
-
|
| 2939 |
-
|
| 2940 |
-
|
| 2941 |
-
|
| 2942 |
-
|
| 2943 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2944 |
)
|
| 2945 |
-
|
| 2946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2947 |
)
|
| 2948 |
|
| 2949 |
x_moe = x_moe_flat.reshape(B, S, D)
|
| 2950 |
|
| 2951 |
-
# Load-balancing aux stats (
|
|
|
|
|
|
|
| 2952 |
r_probs_flat = torch.softmax(
|
| 2953 |
-
routing_logits.reshape(-1,
|
| 2954 |
-
)
|
| 2955 |
-
p_sum = r_probs_flat.sum(dim=0)
|
| 2956 |
-
|
| 2957 |
-
|
| 2958 |
-
|
| 2959 |
-
|
| 2960 |
-
f_counts[eid] = (topk_i_f == eid).float().sum()
|
| 2961 |
-
f_sum = f_counts / (N_tok * self.active_experts) # [N_experts]
|
| 2962 |
aux_stats = (p_sum, f_sum, N_tok)
|
| 2963 |
|
| 2964 |
# ── Difficulty-aware fusion (Eq. 12–13) ──────────────────────────
|
|
@@ -2976,16 +3151,27 @@ class VersatileFFN(nn.Module):
|
|
| 2976 |
# ═════════════════════ INFERENCE ══════════════════════════════════════
|
| 2977 |
else:
|
| 2978 |
loop_choice = depth_logits.argmax(dim=-1) # [B, S]
|
| 2979 |
-
max_loop = int(loop_choice.max().item())
|
| 2980 |
|
| 2981 |
-
# Depth path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2982 |
depth_outputs = []
|
| 2983 |
current_x = x
|
| 2984 |
-
for _ in range(
|
| 2985 |
current_x = self._full_forward_step(current_x)
|
| 2986 |
depth_outputs.append(current_x)
|
| 2987 |
|
| 2988 |
-
depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,
|
| 2989 |
gather_idx = (
|
| 2990 |
loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1)
|
| 2991 |
)
|
|
@@ -2995,40 +3181,63 @@ class VersatileFFN(nn.Module):
|
|
| 2995 |
expected_L = (loop_choice + 1).float() # [B, S]
|
| 2996 |
moe_weight = (self.max_depth - expected_L) / self.max_depth # [B, S]
|
| 2997 |
|
| 2998 |
-
# Width path: conditional on λ > 0 (Conditional Parallelism)
|
| 2999 |
-
active_mask = (moe_weight > 1e-6) # [B, S]
|
| 3000 |
-
x_moe = torch.zeros_like(x)
|
| 3001 |
aux_stats = None
|
| 3002 |
depth_probs = None
|
| 3003 |
|
| 3004 |
-
|
| 3005 |
-
|
| 3006 |
-
|
| 3007 |
-
|
| 3008 |
-
|
| 3009 |
-
|
| 3010 |
-
|
| 3011 |
-
|
| 3012 |
-
|
| 3013 |
-
|
| 3014 |
-
|
| 3015 |
-
|
| 3016 |
-
|
| 3017 |
-
|
| 3018 |
-
|
| 3019 |
-
|
| 3020 |
-
|
| 3021 |
-
|
| 3022 |
-
|
| 3023 |
-
|
| 3024 |
-
|
| 3025 |
-
|
| 3026 |
-
|
| 3027 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3028 |
|
| 3029 |
-
|
| 3030 |
-
x_moe_flat[active_flat] = x_moe_active
|
| 3031 |
-
x_moe = x_moe_flat.reshape(B, S, D)
|
| 3032 |
|
| 3033 |
output = (
|
| 3034 |
x_depth * (1.0 - moe_weight.unsqueeze(-1))
|
|
|
|
| 787 |
x_all: [N, M, d_seed], values in [0, 1], all heads stacked.
|
| 788 |
Returns:
|
| 789 |
[N, M, d_seed, n_knots] float32.
|
| 790 |
+
|
| 791 |
+
NOTE: Este método se mantiene para compatibilidad con JTok-M y análisis.
|
| 792 |
+
El forward del generator ya NO lo usa — usa _compute_head en su lugar.
|
| 793 |
"""
|
| 794 |
x32 = x_all.float()
|
| 795 |
x_e = x32.unsqueeze(-1) # [N, M, d_seed, 1]
|
|
|
|
| 804 |
torch.where(d < 1.5, 0.5 * (1.5 - d) ** 2, torch.zeros_like(d)),
|
| 805 |
) # [N, M, d_seed, n_knots] float32
|
| 806 |
|
| 807 |
+
def _compute_head(
|
| 808 |
+
self,
|
| 809 |
+
z: torch.Tensor,
|
| 810 |
+
m: int,
|
| 811 |
+
) -> torch.Tensor:
|
| 812 |
+
"""
|
| 813 |
+
Forward completo para el cabezal m del generator.
|
| 814 |
+
|
| 815 |
+
Reemplaza la materialización conjunta [N, M, d_seed, n_knots] del path
|
| 816 |
+
vectorizado. Cada llamada materializa solo [N, d_seed, n_knots] (1 cabezal),
|
| 817 |
+
reduciendo el pico de memoria de O(M·d_seed·n_knots) a O(d_seed·n_knots)
|
| 818 |
+
por cabezal.
|
| 819 |
+
|
| 820 |
+
Pipeline:
|
| 821 |
+
z [N, d_seed]
|
| 822 |
+
→ Linear(head_proj_weight[m*d_seed:(m+1)*d_seed]) → [N, d_seed]
|
| 823 |
+
→ ManualLayerNorm(weight[m], bias[m]) → [N, d_seed]
|
| 824 |
+
→ sigmoid(x/2) → [N, d_seed] (coordenada en [0,1]^d_seed)
|
| 825 |
+
→ B-spline KHRONOS con scale=head_scale[m] → [N, d_seed, n_knots]
|
| 826 |
+
→ einsum con head_spline[m] → per_dim [N, d_seed, krank]
|
| 827 |
+
→ sign-parity product (log-sum-exp) → modes [N, krank]
|
| 828 |
+
→ Linear(head_out_weight[m]) → [N, hidden_size]
|
| 829 |
+
|
| 830 |
+
Por qué loop Python sobre M cabezales en lugar de vmap:
|
| 831 |
+
torch.vmap sobre cabezales con parámetros distintos requiere
|
| 832 |
+
functional_call y stack_module_state, lo que complica el acceso
|
| 833 |
+
a buffers (knot_grid, head_norm_eps) desde dentro del transform.
|
| 834 |
+
Un loop Python con M=8 fijo es unrolleado por TorchDynamo en una
|
| 835 |
+
secuencia estática de ops — exactamente como lo hace XLA/Flax en
|
| 836 |
+
la implementación original de Reza. El compilador ve 8 grafos
|
| 837 |
+
idénticos en estructura pero con parámetros distintos, y puede
|
| 838 |
+
fusionarlos u optimizarlos de forma independiente. Con chunk_size=1
|
| 839 |
+
en vmap el comportamiento sería análogo pero con mayor overhead de
|
| 840 |
+
instrumentación.
|
| 841 |
+
|
| 842 |
+
Args:
|
| 843 |
+
z: [N, d_seed] — codebook seed compartido (float del dtype del modelo).
|
| 844 |
+
m: índice del cabezal (0 ≤ m < num_modes), Python int estático.
|
| 845 |
+
Returns:
|
| 846 |
+
[N, hidden_size] — contribución de este cabezal al embedding final.
|
| 847 |
+
"""
|
| 848 |
+
d = self.d_seed
|
| 849 |
+
nk = self.num_knots
|
| 850 |
+
kr = self.krank
|
| 851 |
+
|
| 852 |
+
# ── Proyección lineal para el cabezal m ──────────────────────────
|
| 853 |
+
# head_proj_weight [M*d_seed, d_seed] — los pesos del cabezal m
|
| 854 |
+
# son las filas [m*d_seed : (m+1)*d_seed].
|
| 855 |
+
proj_w = self.head_proj_weight[m * d : (m + 1) * d] # [d_seed, d_seed]
|
| 856 |
+
zh = F.linear(z.float(), proj_w) # [N, d_seed]
|
| 857 |
+
|
| 858 |
+
# ── LayerNorm manual por cabezal ──────────────────────────────────
|
| 859 |
+
# Equivalente a nn.LayerNorm(d_seed) con parámetros independientes
|
| 860 |
+
# head_norm_weight[m] y head_norm_bias[m].
|
| 861 |
+
mean = zh.mean(dim=-1, keepdim=True)
|
| 862 |
+
var = zh.var(dim=-1, keepdim=True, unbiased=False)
|
| 863 |
+
zh = (zh - mean) / (var + self.head_norm_eps).sqrt()
|
| 864 |
+
zh = zh * self.head_norm_weight[m] + self.head_norm_bias[m]
|
| 865 |
+
|
| 866 |
+
# ── Sigmoid(x/2) → coordenada latente en [0,1]^d_seed ────────────
|
| 867 |
+
zh = torch.sigmoid(zh / 2.0).clamp(0.0, 1.0) # [N, d_seed]
|
| 868 |
+
|
| 869 |
+
# ── B-spline KHRONOS para este cabezal ────────────────────────────
|
| 870 |
+
# head_scale[m]: [d_seed] — escala por dimensión para este cabezal.
|
| 871 |
+
# Materializa [N, d_seed, n_knots] en lugar de [N, M, d_seed, n_knots].
|
| 872 |
+
sc = self.head_scale[m].float().view(1, -1, 1) # [1, d_seed, 1]
|
| 873 |
+
x_e = zh.unsqueeze(-1) # [N, d_seed, 1]
|
| 874 |
+
grid = self.knot_grid.float().view(1, 1, -1) # [1, 1, n_knots]
|
| 875 |
+
dist = (x_e - grid).abs() * sc # [N, d_seed, n_knots]
|
| 876 |
+
B_m = torch.where(
|
| 877 |
+
dist < 0.5,
|
| 878 |
+
0.75 - dist ** 2,
|
| 879 |
+
torch.where(dist < 1.5, 0.5 * (1.5 - dist) ** 2, torch.zeros_like(dist)),
|
| 880 |
+
) # [N, d_seed, n_knots]
|
| 881 |
+
|
| 882 |
+
# ── KHRONOS tensor product para este cabezal ──────────────────────
|
| 883 |
+
# head_spline[m]: [d_seed, n_knots, krank]
|
| 884 |
+
# per_dim[n, d, k] = Σ_g B_m[n, d, g] * head_spline[m, d, g, k]
|
| 885 |
+
# Shape: [N, d_seed, krank] — pico máximo en este cabezal.
|
| 886 |
+
per_dim = torch.einsum(
|
| 887 |
+
"ndg,dgk->ndk",
|
| 888 |
+
B_m,
|
| 889 |
+
self.head_spline[m].float(),
|
| 890 |
+
) # [N, d_seed, krank]
|
| 891 |
+
|
| 892 |
+
# Sign-parity log-product (KHRONOS): evita underflow multiplicando
|
| 893 |
+
# en log-space y recuperando el signo por paridad de negativos.
|
| 894 |
+
per_dim_abs = per_dim.abs() + 1e-9
|
| 895 |
+
log_mag = torch.log(per_dim_abs).sum(dim=1) # [N, krank]
|
| 896 |
+
num_neg = (per_dim < 0).long().sum(dim=1) # [N, krank]
|
| 897 |
+
prod_sign = 1.0 - 2.0 * (num_neg % 2).float() # [N, krank]
|
| 898 |
+
modes_m = prod_sign * torch.exp(log_mag) # [N, krank]
|
| 899 |
+
|
| 900 |
+
# ── Proyección de salida del cabezal ──────────────────────────────
|
| 901 |
+
# head_out_weight[m]: [krank, hidden_size]
|
| 902 |
+
# NOTA: NO usar F.linear aquí. F.linear(A, W) computa A @ W.T,
|
| 903 |
+
# esperando W con shape [out, in] = [hidden, krank]. Pero
|
| 904 |
+
# head_out_weight está almacenado como [krank, hidden] (igual que
|
| 905 |
+
# el einsum original "nmk,mkd->nd" que contrae sobre k sin transponer).
|
| 906 |
+
# La multiplicación correcta es modes_m @ W directamente:
|
| 907 |
+
# [N, krank] @ [krank, hidden] → [N, hidden]
|
| 908 |
+
out_m = (
|
| 909 |
+
modes_m.to(self.head_out_weight.dtype)
|
| 910 |
+
@ self.head_out_weight[m]
|
| 911 |
+
) # [N, hidden_size]
|
| 912 |
+
return out_m
|
| 913 |
+
|
| 914 |
def _khronos_all_heads(
|
| 915 |
self,
|
| 916 |
B_all: torch.Tensor,
|
|
|
|
| 1037 |
analysis.z_tilde = z_tilde.detach()
|
| 1038 |
analysis.B_vals = B_vals.detach()
|
| 1039 |
|
| 1040 |
+
# ── Per-head generator path (secuencial, un cabezal a la vez) ──────
|
| 1041 |
+
# ORIGINAL PROBLEM: el path vectorizado anterior procesaba los M
|
| 1042 |
+
# cabezales en paralelo con kernels fusionados:
|
| 1043 |
+
#
|
| 1044 |
+
# _bspline_basis_all_heads → [N, M, d_seed, n_knots] ← TENSOR GIGANTE
|
| 1045 |
+
# _khronos_all_heads → per_dim [N, M, d_seed, krank] ← AÚN MAYOR
|
| 1046 |
+
#
|
| 1047 |
+
# Con N=B*S=32768, M=8, d_seed=128, n_knots=32, krank=16:
|
| 1048 |
+
# [N,M,d_seed,n_knots] = 32768 × 8 × 128 × 32 × 4 bytes ≈ 512 MB
|
| 1049 |
+
# [N,M,d_seed,krank] = 32768 × 8 × 128 × 16 × 4 bytes ≈ 256 MB
|
| 1050 |
+
# Estos tensores viven simultáneamente en el pool de CUDAGraphs,
|
| 1051 |
+
# causando OOM en el backward cuando se suman las activaciones guardadas
|
| 1052 |
+
# de las 12 capas del decoder.
|
| 1053 |
+
#
|
| 1054 |
+
# SOLUCIÓN (equivalente a la impl. JAX de Reza):
|
| 1055 |
+
# Loop Python sobre M=8 cabezales (count fijo → TorchDynamo unrollea
|
| 1056 |
+
# en 8 secuencias de ops estáticas sin graph breaks).
|
| 1057 |
+
# Cada cabezal materializa como máximo [N, d_seed, krank] ≈ 32 MB.
|
| 1058 |
+
# La suma se acumula in-place → el tensor del cabezal anterior puede
|
| 1059 |
+
# ser liberado por el allocator antes de procesar el siguiente.
|
| 1060 |
+
#
|
| 1061 |
+
# Por qué NO vmap(chunk_size=1):
|
| 1062 |
+
# vmap requiere que la función sea "pura" (sin acceso a self.*).
|
| 1063 |
+
# head_norm_eps, knot_grid y los parámetros indexados [m] se pasan
|
| 1064 |
+
# implícitamente a través del closure. Con vmap habría que
|
| 1065 |
+
# stack_module_state + functional_call, lo que añade overhead de
|
| 1066 |
+
# instrumentación sin beneficio real ya que el loop estático es
|
| 1067 |
+
# igualmente trazable por el compilador y produce el mismo grafo.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1068 |
|
| 1069 |
+
target_dtype = self.codebooks.dtype
|
| 1070 |
+
e = torch.zeros(N, self.hidden_size, device=token_ids.device, dtype=target_dtype)
|
| 1071 |
+
for m in range(self.num_modes):
|
| 1072 |
+
e = e + self._compute_head(z, m)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1073 |
|
| 1074 |
# No W_res — confirmed absent in the authors' implementation
|
| 1075 |
e = e.reshape(*orig_shape, self.hidden_size)
|
|
|
|
| 1439 |
|
| 1440 |
class SeeDNorm(nn.Module):
|
| 1441 |
"""
|
| 1442 |
+
Self-Rescaled Dynamic Normalization.
|
| 1443 |
SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
|
| 1444 |
+
|
| 1445 |
+
rescale_factor = tanh(x · β) ∈ (-1, 1) escalar por token
|
| 1446 |
+
dynamic_scale = rescale_factor · α + γ ∈ ℝ^dim
|
| 1447 |
+
output = dynamic_scale ⊙ RMSNorm(x)
|
| 1448 |
"""
|
| 1449 |
|
| 1450 |
def __init__(
|
| 1451 |
self,
|
| 1452 |
dim: int,
|
| 1453 |
eps: float = 1e-6,
|
|
|
|
|
|
|
| 1454 |
):
|
| 1455 |
super().__init__()
|
| 1456 |
+
self.dim = dim
|
| 1457 |
+
self.eps = eps
|
|
|
|
|
|
|
| 1458 |
|
| 1459 |
self.gamma = nn.Parameter(torch.ones(dim))
|
| 1460 |
self.beta = nn.Parameter(torch.zeros(dim))
|
|
|
|
| 1468 |
x: torch.Tensor,
|
| 1469 |
analysis: Optional[SeeDNormAnalysis] = None,
|
| 1470 |
) -> torch.Tensor:
|
|
|
|
| 1471 |
rescale_factor = torch.tanh(
|
| 1472 |
+
torch.sum(x * self.beta, dim=-1, keepdim=True)
|
| 1473 |
)
|
| 1474 |
dynamic_scale = rescale_factor * self.alpha + self.gamma
|
| 1475 |
x_normalized = self._rms_norm(x.float())
|
|
|
|
| 1476 |
output = (x_normalized * dynamic_scale.float()).type_as(x)
|
| 1477 |
if analysis is not None:
|
| 1478 |
analysis.rescale_factor = rescale_factor.detach()
|
|
|
|
| 1482 |
return output
|
| 1483 |
|
| 1484 |
def extra_repr(self) -> str:
|
| 1485 |
+
return f"dim={self.dim}, eps={self.eps}"
|
|
|
|
|
|
|
| 1486 |
|
| 1487 |
|
| 1488 |
# ==================== ROTARY EMBEDDING ====================
|
|
|
|
| 2836 |
- Width path load-balancing returns (output, aux_stats) for integration
|
| 2837 |
with the existing NeoLLMForCausalLM aux-loss accumulation pattern.
|
| 2838 |
|
| 2839 |
+
Width dispatch (CUDAGraph-compatible sparse routing):
|
| 2840 |
+
El dispatch original del paper (torch.where + index_add_) es sparse y
|
| 2841 |
+
fiel al paper pero produce shapes dependientes de datos → incompatible
|
| 2842 |
+
con CUDAGraphs. La implementación usa argsort como dispatcher estático:
|
| 2843 |
+
|
| 2844 |
+
flat_expert [N·K] → argsort → perm [N·K] (shape siempre igual)
|
| 2845 |
+
sorted_tok [N·K] = flat_tok[perm] (índices de token originales)
|
| 2846 |
+
grouped_tok [E, C] = sorted_tok.view(E, C) (C = N·K // E, constante)
|
| 2847 |
+
|
| 2848 |
+
Propiedades clave:
|
| 2849 |
+
· argsort: output shape = input shape, siempre [N·K]. CUDAGraph ✓
|
| 2850 |
+
· C = N_tok·K // E es un entero Python conocido en compile-time.
|
| 2851 |
+
Con el aux loss manteniendo balance, cada experto recibe ≈ C slots.
|
| 2852 |
+
· scatter_add_ con index [C, D] de shape estático: CUDAGraph ✓
|
| 2853 |
+
(los VALORES del index cambian por batch, no el SHAPE).
|
| 2854 |
+
· FLOPs idénticos al original: cada experto procesa [C, D] = [N·K/E, D]
|
| 2855 |
+
tokens, no todos los N tokens. Con K=2, E=4: C = N/2 por experto.
|
| 2856 |
+
|
| 2857 |
+
Conditional Parallelism (inferencia, Algorithm 2):
|
| 2858 |
+
· Los tokens con λ=0 (argmax → max_depth) igualmente participan en el
|
| 2859 |
+
grouped buffer y su expert forward se computa (shapes estáticos).
|
| 2860 |
+
· Su contribución es cancelada por λ=0 en la fusión:
|
| 2861 |
+
output = x_depth·(1−λ) + x_moe·λ → x_depth si λ=0
|
| 2862 |
+
· Esto pierde el saving de FLOPs de los λ=0 tokens, pero la correctitud
|
| 2863 |
+
matemática es exacta. Tradeoff aceptable vs CUDAGraph-incompatibilidad.
|
| 2864 |
+
|
| 2865 |
+
Discrete Early-Exit (inferencia, Algorithm 2):
|
| 2866 |
+
· Sustituido por always-max_depth + torch.gather con loop_choice.
|
| 2867 |
+
Para max_depth=2 el overhead es ≤ 1 iteración extra por token.
|
| 2868 |
+
|
| 2869 |
Reference:
|
| 2870 |
Nie et al. (2026). "VersatileFFN: Achieving Parameter Efficiency in
|
| 2871 |
LLMs via Adaptive Wide-and-Deep Reuse." arXiv:2512.14531.
|
|
|
|
| 3043 |
depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,L]
|
| 3044 |
x_depth = (depth_stack * depth_probs.unsqueeze(2)).sum(dim=-1) # [B,S,D]
|
| 3045 |
|
| 3046 |
+
# ── Width path: argsort-based sparse dispatch (Eq. 7–8) ──────────
|
| 3047 |
+
# Matemática (paper §3.2):
|
| 3048 |
+
# Y_width = Σ_{k∈TopK} g_k · Y_k,
|
| 3049 |
+
# Y_k = H + W_out^(k) φ(W_proj^(k) LayerNorm(H)) (Eq. 8)
|
| 3050 |
+
# Como Σ_{k∈TopK} g_k = 1 (softmax normalizado sobre TopK):
|
| 3051 |
+
# Y_width = H + Σ_{k∈TopK} g_k · delta_k
|
| 3052 |
+
#
|
| 3053 |
+
# Implementación sparse con shapes estáticos:
|
| 3054 |
+
#
|
| 3055 |
+
# 1. flat_expert [N_tok·K]: índices de experto por token-slot.
|
| 3056 |
+
# argsort → perm [N_tok·K] con shape siempre igual. CUDAGraph ✓
|
| 3057 |
+
#
|
| 3058 |
+
# 2. sorted_tok [N_tok·K] = flat_tok[perm]: tokens ordenados por
|
| 3059 |
+
# experto. Todos los tokens del experto e quedan contiguos.
|
| 3060 |
+
#
|
| 3061 |
+
# 3. view(E, C) con C = N_tok·K // E constante Python → shape
|
| 3062 |
+
# estático [E, C, D] para gather y forward.
|
| 3063 |
+
#
|
| 3064 |
+
# 4. _expert_forward sobre [C, D] por experto — mismos FLOPs que
|
| 3065 |
+
# el original con torch.where: solo C tokens por experto,
|
| 3066 |
+
# no los N_tok completos. Con K=2, E=4: C = N_tok/2.
|
| 3067 |
+
#
|
| 3068 |
+
# 5. scatter_add_: index de shape [C, D] siempre estático.
|
| 3069 |
+
# Los VALORES varían por batch, el SHAPE no. CUDAGraph ✓
|
| 3070 |
+
# Acumula Σ_{k} g_k · Y_k para cada token n mediante
|
| 3071 |
+
# sum sobre los K slots que apuntan a n.
|
| 3072 |
+
K = self.active_experts
|
| 3073 |
+
E = self.total_experts
|
| 3074 |
+
N_tok = B * S
|
| 3075 |
+
C = (N_tok * K) // E # tokens por experto — constante compile-time
|
| 3076 |
+
|
| 3077 |
+
routing_logits = self.expert_gate(x) # [B, S, E]
|
| 3078 |
+
topk_w, topk_i = torch.topk(routing_logits, k=K, dim=-1)
|
| 3079 |
+
topk_w = torch.softmax(topk_w, dim=-1) # [B, S, K]
|
| 3080 |
+
|
| 3081 |
+
x_flat = x.reshape(-1, D) # [N_tok, D]
|
| 3082 |
+
x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1]) # [N_tok, fan_dim]
|
| 3083 |
+
|
| 3084 |
+
# Aplanar: cada token aparece K veces, una por experto seleccionado
|
| 3085 |
+
flat_expert = topk_i.reshape(-1) # [N_tok·K] long
|
| 3086 |
+
flat_tok = (
|
| 3087 |
+
torch.arange(N_tok, device=x.device, dtype=torch.long)
|
| 3088 |
+
.unsqueeze(1).expand(N_tok, K).reshape(-1)
|
| 3089 |
+
) # [N_tok·K] long
|
| 3090 |
+
flat_w = topk_w.reshape(-1) # [N_tok·K]
|
| 3091 |
+
|
| 3092 |
+
# Ordenar por expert ID: todos los tokens del mismo experto juntos
|
| 3093 |
+
perm = torch.argsort(flat_expert, stable=True) # [N_tok·K] long
|
| 3094 |
+
sorted_tok = flat_tok[perm] # [N_tok·K] long
|
| 3095 |
+
sorted_w = flat_w[perm] # [N_tok·K]
|
| 3096 |
+
|
| 3097 |
+
# Agrupar por experto [E, C] — C conocido en compile-time
|
| 3098 |
+
grouped_tok = sorted_tok.view(E, C) # [E, C] long
|
| 3099 |
+
grouped_w = sorted_w.view(E, C) # [E, C]
|
| 3100 |
+
|
| 3101 |
+
# Gather features del token original para cada slot de experto
|
| 3102 |
+
flat_idx = grouped_tok.reshape(-1) # [E·C] long
|
| 3103 |
+
fan_dim = x_fan_flat.shape[-1]
|
| 3104 |
+
x_grouped = x_flat[flat_idx].view(E, C, D) # [E, C, D]
|
| 3105 |
+
xf_grouped = x_fan_flat[flat_idx].view(E, C, fan_dim) # [E, C, fan_dim]
|
| 3106 |
+
|
| 3107 |
+
# Expert forward + scatter_add_ de vuelta a [N_tok, D]
|
| 3108 |
+
# Loop desenrollado por dynamo (E constante Python) — sin graph breaks
|
| 3109 |
+
x_moe_flat = torch.zeros(N_tok, D, device=x.device, dtype=x.dtype)
|
| 3110 |
+
for eid in range(E):
|
| 3111 |
+
# out_e [C, D] = x_grouped[eid] + delta_e (residual incluido, Eq. 8)
|
| 3112 |
+
out_e = self._expert_forward(
|
| 3113 |
+
xf_grouped[eid], x_grouped[eid], self.expert_idx[eid]
|
| 3114 |
)
|
| 3115 |
+
w_e = grouped_w[eid].unsqueeze(-1) # [C, 1]
|
| 3116 |
+
tok_idx_e = grouped_tok[eid].unsqueeze(1).expand(C, D) # [C, D] long
|
| 3117 |
+
# Acumula g_k · Y_k en la posición original del token
|
| 3118 |
+
# Cuando eid recorre los K experts de un token n:
|
| 3119 |
+
# x_moe_flat[n] = Σ_k g_k · Y_k = H_n + Σ_k g_k · delta_k
|
| 3120 |
+
x_moe_flat.scatter_add_(
|
| 3121 |
+
0, tok_idx_e, (out_e * w_e).to(x_moe_flat.dtype)
|
| 3122 |
)
|
| 3123 |
|
| 3124 |
x_moe = x_moe_flat.reshape(B, S, D)
|
| 3125 |
|
| 3126 |
+
# Load-balancing aux stats (Eq. load-balancing loss)
|
| 3127 |
+
# p_sum: probabilidad media por experto (sobre routing_logits completo)
|
| 3128 |
+
# f_sum: fracción real de tokens asignados a cada experto
|
| 3129 |
r_probs_flat = torch.softmax(
|
| 3130 |
+
routing_logits.reshape(-1, E), dim=-1
|
| 3131 |
+
) # [N_tok, E]
|
| 3132 |
+
p_sum = r_probs_flat.sum(dim=0) # [E]
|
| 3133 |
+
f_sum = (
|
| 3134 |
+
F.one_hot(flat_expert.long(), E).float().sum(dim=0)
|
| 3135 |
+
/ float(N_tok * K)
|
| 3136 |
+
) # [E]
|
|
|
|
|
|
|
| 3137 |
aux_stats = (p_sum, f_sum, N_tok)
|
| 3138 |
|
| 3139 |
# ── Difficulty-aware fusion (Eq. 12–13) ──────────────────────────
|
|
|
|
| 3151 |
# ═════════════════════ INFERENCE ══════════════════════════════════════
|
| 3152 |
else:
|
| 3153 |
loop_choice = depth_logits.argmax(dim=-1) # [B, S]
|
|
|
|
| 3154 |
|
| 3155 |
+
# ── Depth path: siempre max_depth iteraciones (shape estático) ─
|
| 3156 |
+
# ORIGINAL PROBLEM: el early-exit original usaba
|
| 3157 |
+
# max_loop = int(loop_choice.max().item())
|
| 3158 |
+
# que produce una sincronización CPU-GPU (equivalente a .item())
|
| 3159 |
+
# y hace que el número de iteraciones del loop dependa de datos —
|
| 3160 |
+
# ambas condiciones prohíben la captura de CUDAGraphs.
|
| 3161 |
+
#
|
| 3162 |
+
# SOLUCIÓN: siempre se ejecutan exactamente self.max_depth
|
| 3163 |
+
# iteraciones. depth_stack [B,S,D,max_depth] tiene shape estático.
|
| 3164 |
+
# El gather sobre loop_choice selecciona la salida correcta por
|
| 3165 |
+
# token sin necesidad de conocer cuántas iteraciones se ejecutaron.
|
| 3166 |
+
# La pérdida de FLOPs por iteraciones "extra" es mínima porque
|
| 3167 |
+
# max_depth es pequeño (default 2) y _full_forward_step es ligero.
|
| 3168 |
depth_outputs = []
|
| 3169 |
current_x = x
|
| 3170 |
+
for _ in range(self.max_depth):
|
| 3171 |
current_x = self._full_forward_step(current_x)
|
| 3172 |
depth_outputs.append(current_x)
|
| 3173 |
|
| 3174 |
+
depth_stack = torch.stack(depth_outputs, dim=-1) # [B,S,D,max_depth]
|
| 3175 |
gather_idx = (
|
| 3176 |
loop_choice.unsqueeze(-1).unsqueeze(-1).expand(B, S, D, 1)
|
| 3177 |
)
|
|
|
|
| 3181 |
expected_L = (loop_choice + 1).float() # [B, S]
|
| 3182 |
moe_weight = (self.max_depth - expected_L) / self.max_depth # [B, S]
|
| 3183 |
|
|
|
|
|
|
|
|
|
|
| 3184 |
aux_stats = None
|
| 3185 |
depth_probs = None
|
| 3186 |
|
| 3187 |
+
# ── Width path: argsort-based sparse dispatch (mismo mecanismo
|
| 3188 |
+
# que entrenamiento, Eq. 7–8 + Conditional Parallelism §A) ───
|
| 3189 |
+
#
|
| 3190 |
+
# Conditional Parallelism (Algorithm 2 del paper):
|
| 3191 |
+
# Si λ=0 para un token → Y = Y_depth, el width path se omite.
|
| 3192 |
+
# Con shapes estáticos no podemos excluir dinámicamente esos tokens
|
| 3193 |
+
# del buffer. En su lugar, los λ=0 tokens participan en el grouped
|
| 3194 |
+
# buffer y su expert forward corre, pero la fusión
|
| 3195 |
+
# output = x_depth·(1−λ) + x_moe·λ
|
| 3196 |
+
# garantiza output = x_depth cuando λ=0, sin ninguna rama condicional.
|
| 3197 |
+
# Los FLOPs del width path para esos tokens son el único overhead.
|
| 3198 |
+
N_tok_inf = B * S
|
| 3199 |
+
K_inf = self.active_experts
|
| 3200 |
+
E_inf = self.total_experts
|
| 3201 |
+
C_inf = (N_tok_inf * K_inf) // E_inf
|
| 3202 |
+
|
| 3203 |
+
x_flat = x.reshape(-1, D)
|
| 3204 |
+
x_fan_flat = x_fan.reshape(-1, x_fan.shape[-1])
|
| 3205 |
+
|
| 3206 |
+
routing_logits = self.expert_gate(x_flat) # [N_tok, E]
|
| 3207 |
+
tw, ti = torch.topk(routing_logits, k=K_inf, dim=-1)
|
| 3208 |
+
tw = torch.softmax(tw, dim=-1) # [N_tok, K]
|
| 3209 |
+
|
| 3210 |
+
flat_expert_i = ti.reshape(-1) # [N_tok·K] long
|
| 3211 |
+
flat_tok_i = (
|
| 3212 |
+
torch.arange(N_tok_inf, device=x.device, dtype=torch.long)
|
| 3213 |
+
.unsqueeze(1).expand(N_tok_inf, K_inf).reshape(-1)
|
| 3214 |
+
) # [N_tok·K] long
|
| 3215 |
+
flat_w_i = tw.reshape(-1) # [N_tok·K]
|
| 3216 |
+
|
| 3217 |
+
perm_i = torch.argsort(flat_expert_i, stable=True) # [N_tok·K]
|
| 3218 |
+
sorted_tok_i = flat_tok_i[perm_i] # [N_tok·K]
|
| 3219 |
+
sorted_w_i = flat_w_i[perm_i] # [N_tok·K]
|
| 3220 |
+
|
| 3221 |
+
grouped_tok_i = sorted_tok_i.view(E_inf, C_inf) # [E, C]
|
| 3222 |
+
grouped_w_i = sorted_w_i.view(E_inf, C_inf) # [E, C]
|
| 3223 |
+
|
| 3224 |
+
flat_idx_i = grouped_tok_i.reshape(-1) # [E·C]
|
| 3225 |
+
fan_dim_i = x_fan_flat.shape[-1]
|
| 3226 |
+
x_grouped_i = x_flat[flat_idx_i].view(E_inf, C_inf, D) # [E, C, D]
|
| 3227 |
+
xf_grouped_i = x_fan_flat[flat_idx_i].view(E_inf, C_inf, fan_dim_i) # [E, C, fan_dim]
|
| 3228 |
+
|
| 3229 |
+
x_moe_flat_i = torch.zeros(N_tok_inf, D, device=x.device, dtype=x.dtype)
|
| 3230 |
+
for eid in range(E_inf):
|
| 3231 |
+
out_e_i = self._expert_forward(
|
| 3232 |
+
xf_grouped_i[eid], x_grouped_i[eid], self.expert_idx[eid]
|
| 3233 |
+
)
|
| 3234 |
+
w_e_i = grouped_w_i[eid].unsqueeze(-1) # [C, 1]
|
| 3235 |
+
tok_idx_e_i = grouped_tok_i[eid].unsqueeze(1).expand(C_inf, D)
|
| 3236 |
+
x_moe_flat_i.scatter_add_(
|
| 3237 |
+
0, tok_idx_e_i, (out_e_i * w_e_i).to(x_moe_flat_i.dtype)
|
| 3238 |
+
)
|
| 3239 |
|
| 3240 |
+
x_moe = x_moe_flat_i.reshape(B, S, D)
|
|
|
|
|
|
|
| 3241 |
|
| 3242 |
output = (
|
| 3243 |
x_depth * (1.0 - moe_weight.unsqueeze(-1))
|