diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index e4df5c5bca68f5679683843aab7107a65e08c36c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cb6163428ce86500d61c2b765eecd7eb6f31c092066278e1d1af7a0848dc5126 -size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8f9c672be7ef1f016613e205dfc115f9af85d8e2 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:075fc73dbb2750aed7598cc3e13b593b6b1e7a78a78491e1b852fbd2a9af8f8d +size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/muon.py b/build/torch210-cxx11-cu126-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 8f98e5170bbbee1e7030043a85edac0e7500b08c..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:017323d479e8fbd3ed1f550f95fc4ba9f2e304dbe9351c0eaa75543ebe775e18 -size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2926579ac6973002c2d4067df7122743ecd1567d --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2af397ae01c8c01ee0e879f6812bd9df55d152afbcc6713f5c1987d5bce7793b +size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/muon.py b/build/torch210-cxx11-cu128-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 69f073c51c32940bfa69db0e73d866bd2109dac2..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:783a161f2d28e4244226c9d6e59ac33f74f7a79aad17c06e8ce027dd6182e03c -size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..570449600e33b32b83585ec10fe0593b4c4318bc --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45eef069a7caa85678cd1e05f0c60c5cfbc676dc93a1bcb31e55eb34730aa469 +size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/muon.py b/build/torch210-cxx11-cu130-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index bf5569466f9bc68669f2f0730af8200f8b4ee267..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8ec2fcc8a9dc8a1e4aa4e925eaee33613177873e474e8d627bf844dae80f5f8b -size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d24aaf499164890a854ed9b06cfb6439f0c392a9 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:055206c495ecade2fe4b5427db34f0a48152174e79808cbe1ce7d7ca86d32396 +size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 5c860d5ebc24f59df20241eeee3d0a2211cd519f..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:18373b2e448071735ce724008122f179dd814986925c9cf0fc03f32201b2b1fa -size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..482f777c94bf8d03311ce7c0be278ece8264857f --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:315ff09ffa88ec806cb8abe49edb2ca6951e9ac34be3d3e10f159093f9576ee0 +size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 335bba4b9fd49cfedcdc7364de8cf2e343290c81..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d2db9c7fb764a1fae1872779bc9ffac2aff18d14a238111d6b8b53b7d3dfa0d3 -size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..49cb725bbd4c9011ecc4ad53c60007d4b39e93c4 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9a7c1beffbad405ef7d6f46f44cf9c6671d119e04a340b54c8f4c8f9d699caf +size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index b34eb7954b7ffac75892de5b572008620001470f..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b1ddfe7e38a9213d5dede8052c81b78eca952aef122d4da919950ff504dc3908 -size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..99e4220345748e9633a27b083af5e5ac2605be0b --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:090f5a44cdfa4554147159cc36bb7e8ee9dba1ffb1fea4825aa838461fdaddf9 +size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 433642a40078da8da028d3a3a7765aa080e7fbb4..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:db68ba26f1b022f56a5ab4e6e0204bf26df8922750f32f21be0ad76e2674b717 -size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e1e05434ff697a1f0daae635177d29e1b6f8531b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46baa92bf8f5ec5913df4081a01f662049fda475eb01bc7ed0f6154755fa88d5 +size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 2b9d7cf21cd452edc218f40112277ade1b8a0ef4..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5a6a8788f055b22d594330fc06487ae2c6eeb2b64e0ab0132b68036a78560cf6 -size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0df93dd2ea24c80cbed1f029804c4e4e480a140e --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcf5b8838dfaf6e81fdbd52ff4638ca76abaa678f7c2cbd81cf03dc72f9cd5d2 +size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index ca6aaefd80e0a6fb2554b0d7834380ad039c55cb..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6e5e17fd042010ec06456f5885603c4e38476981d43adb1cc99ea6dbe5f57c6f -size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9d0435e26c1566f1da7309ceeb31f49e290217cb --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f43a2025f967fcd94fcada1e8a708956f07774c522f156caab135d7162c7a91 +size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index f19507bac03da8dc8cef114c88bf669bee77834d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f6b4b64b9e80383e0e1e9d5482f39dc28256bf901211d60deda16d905f198e70 -size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8b49da51c7e1e432368980c185a0456d90b49f70 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ba2d5675067b94d6327adc7d124fa9ac534af8236b6783981e13b47a5ee603b +size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index be6c38f77dfea6126d9e420956e78608f814e17b..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:52bc357d5a5d09094e142d1bf87e9a2ba819a6770b8b34ab9469ceb2414ad29e -size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0ecf56af29245daebb0f2048f89e30d786c9a8a7 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e9768c250f77d75c777d94e1f7817cb24372b8d644bc28d3edfa1a829317272 +size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 420a438b4ae1bd20e685b29cd293000be31bdcca..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:faa0fc60fea48f7ba85933fb35d3c96afb6f5b357d4e28565e8112848fed890a -size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c3a9a8d14613666ccf77a88386d75632e689b3bd --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b73946a6e5c0366cfb776d7d553e6343a354392da029564aff8ad0d961ffa25b +size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index e1e49bb3bce5f978495a622f9b64993c4c024136..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d8cf990229d1c3dc8378e74487af28ab48ffb91ec12869ddb0839d3b4cddc03e -size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1c0c2fe89693c7112bf2ec2e0ea203ba1a8292bb --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9bdbe93877a276dbfdd4c596b788d70ed61804c4c5bc555259c6a3be0e9ec8fe +size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py index 2b9a835b2bee66a402df46da0550a602812ddece..034bff088659b5df6f6d401feb18c89dc5f33b29 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_38f9b8e_dirty -ops = torch.ops._optimizer_38f9b8e_dirty +from . import _optimizer_8d53b78_dirty +ops = torch.ops._optimizer_8d53b78_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_38f9b8e_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_8d53b78_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so deleted file mode 100755 index 32ee070eb44f56f58078c35e8ad32fcc4d39e44c..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_38f9b8e_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fbdc2be035c0380bdd9ea10a0f913ecf5b6be29d4d7d74e1bd4056143393f28d -size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b11d926ee80d66a0e597271dd7ecf3e693493358 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_8d53b78_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3186e65c7ef03c5272229d1b155fa3b4309ae2c119f149e20bbae41c64cd754 +size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged) diff --git a/docs/muon-clip.md b/docs/muon-clip.md new file mode 100644 index 0000000000000000000000000000000000000000..bcb77228e11beb4a2407085f34d19ebc12e54d56 --- /dev/null +++ b/docs/muon-clip.md @@ -0,0 +1,317 @@ +# QK-Clip for MuonClip Optimizer (MLA) + +> Reference: [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534), Section 2.1, Algorithm 1 + +## 개요 + +QK-Clip은 Muon optimizer에서 발생하는 attention logit explosion을 방지하기 위한 **weight rescaling** 기법이다. +forward/backward에는 개입하지 않고, optimizer step **이후**에 weight를 rescale하여 logit 성장을 원천 차단한다. + +## Algorithm 1: MuonClip + +``` +for each training step t: + // 1. Muon optimizer step + for each weight W: + Mt = µ·Mt-1 + Gt + Ot = Newton-Schulz(Mt) · √max(n,m) · 0.2 + Wt = Wt-1 - η·(Ot + λ·Wt-1) + + // 2. QK-Clip + for each attention head h: + S^h_max ← forward에서 기록한 head h의 max pre-softmax logit + if S^h_max > τ: + γ ← τ / S^h_max + W^h_qc ← W^h_qc · √γ (query compressed, q_nope) + W^h_kc ← W^h_kc · √γ (key compressed, k_nope) + W^h_qr ← W^h_qr · γ (query rotary, q_pe) + // k_R (shared rotary, k_pe): 안 건드림 +``` + +## 기존 코드 → MLA 수도코드 + +### 현재 코드 구조 (MHA/GQA) + +``` +parse_qk_layer(name) → wq/wk 여부 판별, layer index 추출 +get_qk_clip_info(config, n) → QKClipInfo (kind, indices, head_dim, threshold, logit) +compute_scales(p, info) → per-head √γ scales 텐서 반환 +qk_clip(p, scales, head_dim) → W.view(-1, head_dim, in_dim).mul_(scales) +``` + +현재 코드는 head_dim이 균일하고, Q/K weight 전체에 동일한 √γ를 적용한다. + +### MLA에서 달라지는 점 + +| 항목 | MHA/GQA (현재) | MLA | +|---|---|---| +| Q weight | `wq` / `q_proj` | `wq_b` (up-proj from LoRA) | +| K weight | `wk` / `k_proj` | `wkv_b` (k_nope + v 합쳐져 있음) | +| Q head stride | `qk_head_dim` (균일) | `qk_head_dim` = `qk_nope_head_dim + qk_rope_head_dim` | +| K head stride | `qk_head_dim` (균일) | `kv_stride` = `qk_nope_head_dim + v_head_dim` | +| Q scaling | 전체 √γ | nope → √γ, rope → γ (서로 다름) | +| K scaling | 전체 √γ | k_nope → √γ, v → 1.0 (부분만) | +| shared k_pe | 없음 | `wkv_a` 뒷부분, 안 건드림 | + +### 수도코드: parse_qk_layer (MLA 확장) + +```python +def parse_qk_layer(name: str) -> tuple[str | None, int]: + parts = normalize_fqn(name).split('.') + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + # MHA/GQA: wq, wk, q_proj, k_proj + # MLA: wq_b (Q up-proj), wkv_b (KV up-proj) + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): + return kind, layer_idx + + return None, -1 +``` + +### 수도코드: QKClipInfo (MLA 확장) + +```python +@dataclass +class QKClipInfo: + kind: str | None # 'wq_b' or 'wkv_b' (MLA) / 'wq','wk' (MHA) + indices: list[int] # clipping 대상 head indices + head_dim: int # 기존 MHA용 (uniform stride) + threshold: float + logit: torch.Tensor | None + + # MLA 전용 필드 + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 +``` + +### 수도코드: get_qk_clip_info (MLA 확장) + +```python +def get_qk_clip_info(clip_config, n, qk_logits): + if clip_config is None: + return None + + threshold = clip_config['threshold'] + kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + if isinstance(logit, DTensor): + logit = logit.full_tensor() + + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=clip_config['head_dim'], # qk_head_dim (for wq_b) + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + # 기존 MHA/GQA 경로 + return QKClipInfo( + kind=kind, indices=indices, + head_dim=clip_config['head_dim'], + threshold=threshold, logit=logit, + ) +``` + +### 수도코드: compute_scales (MLA 확장) + +기존과 동일하게 per-head γ를 계산한다. (γ 결정은 MHA와 동일) +달라지는 건 `qk_clip` 적용 시 head 내부를 sub-region별로 나눠서 다른 변환을 쓰는 것이다. + +```python +def compute_scales(p, qk_clip_state): + """기존 코드와 동일. per-head √γ 반환.""" + kind = qk_clip_state.kind + indices = qk_clip_state.indices + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + head_scales = {} + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) # √γ + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale + + if not head_scales: + return None + + H_global = p.shape[0] // qk_clip_state.head_dim # MLA: head_dim = qk_head_dim or kv_stride + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale # √γ_h + + return scales_full +``` + +### 수도코드: qk_clip (MLA 확장) + +per-head scales(√γ)는 동일하게 받되, head 내부 sub-region에 다른 함수를 적용한다. + +```python +def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None): + """ + scales: [n_heads] 텐서, 각 원소 = √γ_h + + is_mla=False: 기존 MHA/GQA (head 내 uniform √γ) + is_mla=True: MLA (head 내 sub-region별 다른 변환) + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not is_mla: + # 기존: 모든 행에 √γ 균일 적용 + W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: head별로 sub-region 분리 적용 + if kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_rope = info.qk_rope_head_dim + qk_head_dim = qk_nope + qk_rope + + for h in range(len(scales)): + sqrt_gamma = scales[h].item() + if sqrt_gamma >= 1.0: + continue + gamma = sqrt_gamma * sqrt_gamma # √γ → γ + s = h * qk_head_dim + + W[s : s + qk_nope] *= sqrt_gamma # q_nope → √γ + W[s + qk_nope : s + qk_head_dim] *= gamma # q_pe → γ + + elif kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + + for h in range(len(scales)): + sqrt_gamma = scales[h].item() + if sqrt_gamma >= 1.0: + continue + s = h * kv_stride + + W[s : s + qk_nope] *= sqrt_gamma # k_nope → √γ + # v 행: 안 건드림 +``` + +### 수도코드: GQA에서 wkv_b indices 처리 + +Q head → KV head 매핑이 필요하다. +여러 Q head가 같은 KV head를 공유하므로, **group 내 최소 gamma** 기준으로 한 번만 적용해야 한다. + +```python +def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads): + """ + Q head 기준 logit으로부터 KV head indices를 생성한다. + q_indices가 Q head index 기준이라면, + k_indices는 대응되는 KV head index로 변환해야 한다. + + 주의: 같은 KV head에 매핑되는 여러 Q head 중 + 가장 큰 logit (= 가장 작은 gamma)을 사용해야 한다. + """ + heads_per_kv = n_heads // n_kv_heads + q_indices = clip_config.get('q_indices', list(range(n_heads))) + + # Q head → KV head 매핑 + # logit 텐서에서 같은 kv_head에 대응되는 Q head들 중 max를 취하는 것은 + # compute_scales_mla 내부에서 min(gamma) 로 처리됨 + + k_indices = [] + seen = set() + for q_idx in q_indices: + kv_idx = q_idx // heads_per_kv + if kv_idx not in seen: + k_indices.append(kv_idx) + seen.add(kv_idx) + + return k_indices +``` + +### 수도코드: 호출 흐름 (통합) + +```python +# optimizer step 이후 호출되는 부분 (기존 코드 구조 유지) + +for name, param in model.named_parameters(): + info = get_qk_clip_info(clip_config, name, qk_logits) + if info is None or info.kind is None: + continue + + scales = compute_scales(param, info) # per-head √γ (MHA/MLA 공통) + if scales is not None: + qk_clip(param, scales, info.head_dim, + is_mla=info.is_mla, kind=info.kind, info=info) +``` + +### 수도코드: clip_config 예시 + +```python +# MHA/GQA (기존) +clip_config = { + 'head_dim': 128, + 'threshold': 100.0, + 'q_indices': list(range(n_heads)), + 'k_indices': list(range(n_kv_heads)), +} + +# MLA (확장) +clip_config = { + 'is_mla': True, + 'head_dim': 192, # qk_head_dim (= qk_nope + qk_rope) + 'qk_nope_head_dim': 128, + 'qk_rope_head_dim': 64, + 'v_head_dim': 128, + 'threshold': 100.0, + 'q_indices': list(range(n_heads)), + 'k_indices': list(range(n_kv_heads)), # build_k_indices_for_mla로 생성 +} +``` + +## 행 인덱스 매핑 테이블 + +| 알고리즘 기호 | 텐서 | 행 범위 | scale | +|---|---|---|---| +| W^h_qc | `wq_b.weight` | `[h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim]` | √γ | +| W^h_qr | `wq_b.weight` | `[h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim]` | γ | +| W^h_kc | `wkv_b.weight` | `[kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim]` | √γ | +| k_R | `wkv_a` output 뒷부분 | - | 안 건드림 | + +- `kv_stride = qk_nope_head_dim + v_head_dim` +- `kv_h = h // (n_heads // n_kv_heads)` (GQA head 매핑) + +## 하이퍼파라미터 + +| 파라미터 | 값 | 비고 | +|---|---|---| +| τ (threshold) | 100 | K2 full-scale 학습 | +| τ (aggressive) | 30 | 소규모 ablation, 성능 저하 없음 확인 | + +## 참고사항 + +- **Self-deactivation**: K2에서 초기 70k step 동안 12.7%의 head만 trigger됨. 이후 모든 head의 S_max가 τ 아래로 내려가면서 자연스럽게 비활성화. +- **DP/TP 환경**: S^h_max를 all-reduce로 모든 rank에서 max 수집 필요. +- **GQA 중복 적용 방지**: 같은 KV head를 공유하는 Q head group에서 가장 작은 gamma(= 가장 큰 logit)를 기준으로 KV weight를 한 번만 scaling. `compute_scales_mla`에서 `min(gamma)` 로직으로 처리. +- **wq_b_gate**: attention logit이 아닌 output gate에만 관여하므로 QK-Clip 대상 아님. +- **기존 logit soft-cap**: forward-level safety net으로 남겨두되, optimizer-level QK-Clip을 추가하는 것이 논문의 접근법. diff --git a/test/test_cpu_offload.py b/test/test_cpu_offload.py index 54f2f4dc26815e7b7a07a9e46ecbecea90767f3f..0927e5cd490d0340d25b48fd4fb3c34e1c893a2d 100644 --- a/test/test_cpu_offload.py +++ b/test/test_cpu_offload.py @@ -11,7 +11,6 @@ Tests: import copy import logging -import sys import pytest import torch @@ -30,8 +29,7 @@ def _setup(): def _make_mesh(world_size): - return dist.init_device_mesh("cuda", (world_size, ), - mesh_dim_names=("dp", )) + return dist.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) def test_correctness(rank, world_size): @@ -49,12 +47,11 @@ def test_correctness(rank, world_size): num_steps = 3 # Pre-generate all data on all ranks (same seed → same values). - full_params = [ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) + full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] + full_grads = [ + [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] + for _ in range(num_steps) ] - full_grads = [[ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) - ] for _ in range(num_steps)] def make_optimizer(cpu_offload): params, names = [], [] @@ -63,17 +60,19 @@ def test_correctness(rank, world_size): p = torch.nn.Parameter(dt) params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) if cpu_offload: optim.turn_on_cpu_offload() @@ -122,22 +121,25 @@ def test_memory(rank, world_size): full = torch.randn(dim0, dim1, device="cuda") dt = distribute_tensor(full, mesh, [Shard(0)]) p = torch.nn.Parameter(dt) - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) optim.turn_on_cpu_offload() @@ -153,7 +155,8 @@ def test_memory(rank, world_size): local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf assert local_buf.untyped_storage().size() == 0, ( f"Expected freed GPU storage after offload, got " - f"{local_buf.untyped_storage().size()} bytes") + f"{local_buf.untyped_storage().size()} bytes" + ) # Verify CPU pool has pinned buffers. pool = optim._cpu_offload_pool @@ -163,8 +166,9 @@ def test_memory(rank, world_size): # Run another step to verify reload + compute + offload cycle works. for p in params: - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) optim.step() torch.cuda.synchronize() @@ -213,21 +217,26 @@ def test_adamw_offload(rank, world_size): adamw_names.append(f"layer.{i}.bias") # Pre-generate grads. - muon_grads = [[torch.randn(64, 128, device="cuda") for _ in range(4)] - for _ in range(num_steps)] - adamw_grads = [[torch.randn(128, device="cuda") for _ in range(3)] - for _ in range(num_steps)] + muon_grads = [ + [torch.randn(64, 128, device="cuda") for _ in range(4)] + for _ in range(num_steps) + ] + adamw_grads = [ + [torch.randn(128, device="cuda") for _ in range(3)] for _ in range(num_steps) + ] def make_optimizer(cpu_offload): mp = [ torch.nn.Parameter( - distribute_tensor(p.data.full_tensor().clone(), mesh, - [Shard(0)])) for p in muon_params + distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)]) + ) + for p in muon_params ] ap = [ torch.nn.Parameter( - distribute_tensor(p.data.full_tensor().clone(), mesh, - [Shard(0)])) for p in adamw_params + distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)]) + ) + for p in adamw_params ] param_groups = [ { @@ -297,7 +306,8 @@ def test_adamw_offload(rank, world_size): t = state[key] local_t = t._local_tensor if isinstance(t, DTensor) else t assert local_t.untyped_storage().size() == 0, ( - f"AdamW {key} storage not freed after offload") + f"AdamW {key} storage not freed after offload" + ) set_ns_compile(True) if rank == 0: @@ -325,22 +335,25 @@ def test_memory_savings(rank, world_size): full = torch.randn(dim0, dim1, device="cuda") dt = distribute_tensor(full, mesh, [Shard(0)]) p = torch.nn.Parameter(dt) - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) if cpu_offload: optim.turn_on_cpu_offload() @@ -357,17 +370,16 @@ def test_memory_savings(rank, world_size): mem_with_offload = run_step(True) if rank == 0: - logger.info("Memory without offload: %.2f MB", - mem_no_offload / 1024**2) - logger.info("Memory with offload: %.2f MB", - mem_with_offload / 1024**2) + logger.info("Memory without offload: %.2f MB", mem_no_offload / 1024**2) + logger.info("Memory with offload: %.2f MB", mem_with_offload / 1024**2) saved = mem_no_offload - mem_with_offload logger.info("Memory saved: %.2f MB", saved / 1024**2) assert mem_with_offload < mem_no_offload, ( f"Expected memory reduction with CPU offload. " f"Without: {mem_no_offload / 1024**2:.2f} MB, " - f"With: {mem_with_offload / 1024**2:.2f} MB") + f"With: {mem_with_offload / 1024**2:.2f} MB" + ) set_ns_compile(True) if rank == 0: @@ -388,12 +400,11 @@ def test_toggle_correctness(rank, world_size): num_params = 4 num_steps = 6 - full_params = [ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) + full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] + full_grads = [ + [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] + for _ in range(num_steps) ] - full_grads = [[ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) - ] for _ in range(num_steps)] def make_optimizer(): params, names = [], [] @@ -402,17 +413,19 @@ def test_toggle_correctness(rank, world_size): p = torch.nn.Parameter(dt) params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) return optim, params @@ -433,8 +446,7 @@ def test_toggle_correctness(rank, world_size): for i in range(num_params): g = full_grads[step_idx][i] params_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) - params_toggle[i].grad = distribute_tensor( - g.clone(), mesh, [Shard(0)]) + params_toggle[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) optim_ref.step() optim_toggle.step() @@ -445,8 +457,11 @@ def test_toggle_correctness(rank, world_size): torch.testing.assert_close(ref_full, tog_full, atol=0, rtol=0) if rank == 0: - logger.info("Step %d (offload=%s): toggle correctness OK", - step_idx, optim_toggle.cpu_offload) + logger.info( + "Step %d (offload=%s): toggle correctness OK", + step_idx, + optim_toggle.cpu_offload, + ) set_ns_compile(True) if rank == 0: @@ -477,17 +492,19 @@ def test_leak(rank, world_size): params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) optim.turn_on_cpu_offload() @@ -502,8 +519,9 @@ def test_leak(rank, world_size): for step_idx in range(num_steps): for p in params: - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) optim.step() torch.cuda.synchronize() @@ -518,8 +536,12 @@ def test_leak(rank, world_size): cpu_after_warmup = cpu_mem if rank == 0 and step_idx % 10 == 0: - logger.info("Step %d: GPU alloc=%.2f MB, CPU RSS=%.2f MB", - step_idx, gpu_mem / (1024**2), cpu_mem) + logger.info( + "Step %d: GPU alloc=%.2f MB, CPU RSS=%.2f MB", + step_idx, + gpu_mem / (1024**2), + cpu_mem, + ) # Final measurements. torch.cuda.synchronize() @@ -527,15 +549,23 @@ def test_leak(rank, world_size): cpu_final = get_cpu_rss_mb() if rank == 0: - logger.info("After %d steps: GPU alloc=%.2f MB, CPU RSS=%.2f MB", - num_steps, gpu_final / (1024**2), cpu_final) - logger.info("Warmup baseline: GPU alloc=%.2f MB, CPU RSS=%.2f MB", - gpu_after_warmup / (1024**2), cpu_after_warmup) + logger.info( + "After %d steps: GPU alloc=%.2f MB, CPU RSS=%.2f MB", + num_steps, + gpu_final / (1024**2), + cpu_final, + ) + logger.info( + "Warmup baseline: GPU alloc=%.2f MB, CPU RSS=%.2f MB", + gpu_after_warmup / (1024**2), + cpu_after_warmup, + ) # GPU memory should not grow beyond warmup baseline. assert gpu_final <= gpu_after_warmup, ( f"GPU memory leak detected! Warmup: {gpu_after_warmup / 1024**2:.2f} MB, " - f"Final: {gpu_final / 1024**2:.2f} MB") + f"Final: {gpu_final / 1024**2:.2f} MB" + ) # CPU RSS should not grow more than 50 MB over warmup (allows for minor # Python/CUDA runtime overhead but catches real leaks). @@ -543,12 +573,12 @@ def test_leak(rank, world_size): assert cpu_growth < 50, ( f"CPU memory leak detected! Growth: {cpu_growth:.2f} MB over " f"{num_steps - 2} steps (warmup={cpu_after_warmup:.2f} MB, " - f"final={cpu_final:.2f} MB)") + f"final={cpu_final:.2f} MB)" + ) set_ns_compile(True) if rank == 0: - logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)", - cpu_growth) + logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)", cpu_growth) def test_state_dict_save_load(rank, world_size): @@ -576,26 +606,28 @@ def test_state_dict_save_load(rank, world_size): num_steps = 3 # Pre-generate all data. - muon_init = [ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon) - ] + muon_init = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)] adamw_init = [torch.randn(dim1, device="cuda") for _ in range(num_adamw)] - all_grads_muon = [[ - torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon) - ] for _ in range(num_steps * 2)] - all_grads_adamw = [[ - torch.randn(dim1, device="cuda") for _ in range(num_adamw) - ] for _ in range(num_steps * 2)] + all_grads_muon = [ + [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)] + for _ in range(num_steps * 2) + ] + all_grads_adamw = [ + [torch.randn(dim1, device="cuda") for _ in range(num_adamw)] + for _ in range(num_steps * 2) + ] def make_optimizer(cpu_offload): mp = [ torch.nn.Parameter( - distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)])) + distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)]) + ) for i in range(num_muon) ] ap = [ torch.nn.Parameter( - distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)])) + distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)]) + ) for i in range(num_adamw) ] param_groups = [ @@ -634,15 +666,17 @@ def test_state_dict_save_load(rank, world_size): for step_idx in range(num_steps): for i in range(num_muon): mp_off[i].grad = distribute_tensor( - all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)]) + all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)] + ) for i in range(num_adamw): ap_off[i].grad = distribute_tensor( - all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)]) + all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)] + ) optim_off.step() with pytest.raises( - RuntimeError, - match="turn_off_cpu_offload\\(\\) before checkpoint save"): + RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save" + ): optim_off.state_dict() optim_off.turn_off_cpu_offload() @@ -654,7 +688,8 @@ def test_state_dict_save_load(rank, world_size): if isinstance(val, torch.Tensor) and val.is_floating_point(): assert val.untyped_storage().size() > 0, ( f"state_dict() returned empty storage for key '{key}' — " - f"offload reload is broken") + f"offload reload is broken" + ) if rank == 0: logger.info("state_dict() contains valid (non-empty) tensors") @@ -689,8 +724,8 @@ def test_state_dict_save_load(rank, world_size): for i in range(num_adamw): ap_ref[i].data = ap_off[i].data.clone() with pytest.raises( - RuntimeError, - match="turn_off_cpu_offload\\(\\) before checkpoint load"): + RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" + ): optim_ref.load_state_dict(copy.deepcopy(sd_off)) optim_ref.turn_off_cpu_offload() optim_ref.load_state_dict(copy.deepcopy(sd_off)) @@ -714,8 +749,8 @@ def test_state_dict_save_load(rank, world_size): if flat_key in flat_target: param_state[key] = flat_target[flat_key] with pytest.raises( - RuntimeError, - match="turn_off_cpu_offload\\(\\) before checkpoint load"): + RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" + ): optim_resumed.load_state_dict(copy.deepcopy(sd_loaded)) optim_resumed.turn_off_cpu_offload() optim_resumed.load_state_dict(sd_loaded) @@ -760,7 +795,8 @@ def test_state_dict_save_load(rank, world_size): buf = state["momentum_buffer"] local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf assert local_buf.untyped_storage().size() == 0, ( - "Resumed optimizer should have offloaded state after step()") + "Resumed optimizer should have offloaded state after step()" + ) set_ns_compile(True) if rank == 0: @@ -785,22 +821,25 @@ def test_checkpoint_memory(rank, world_size): full = torch.randn(dim0, dim1, device="cuda") dt = distribute_tensor(full, mesh, [Shard(0)]) p = torch.nn.Parameter(dt) - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) params.append(p) names.append(f"layer.{i}.weight") - param_groups = [{ - "params": params, - "names": names, - "use_muon": True, - "lr": 0.02, - "weight_decay": 0.01, - "momentum": 0.95, - "nesterov": True, - "ns_steps": 5, - "none_grad": False, - }] + param_groups = [ + { + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + } + ] optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) optim.turn_on_cpu_offload() @@ -822,13 +861,14 @@ def test_checkpoint_memory(rank, world_size): if rank == 0: logger.info( - "After step (offloaded): GPU alloc=%.2f MB, " - "expected state size=%.2f MB", mem_after_step / 1024**2, - state_bytes / 1024**2) + "After step (offloaded): GPU alloc=%.2f MB, expected state size=%.2f MB", + mem_after_step / 1024**2, + state_bytes / 1024**2, + ) with pytest.raises( - RuntimeError, - match="turn_off_cpu_offload\\(\\) before checkpoint save"): + RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save" + ): optim.state_dict() optim.turn_off_cpu_offload() @@ -837,48 +877,57 @@ def test_checkpoint_memory(rank, world_size): sd_for_load = copy.deepcopy(optim.state_dict()) if rank == 0: - logger.info("After turn_off_cpu_offload: GPU alloc=%.2f MB", - mem_after_turn_off / 1024**2) + logger.info( + "After turn_off_cpu_offload: GPU alloc=%.2f MB", + mem_after_turn_off / 1024**2, + ) assert mem_after_turn_off > mem_after_step, ( f"turn_off_cpu_offload() should reload states to GPU. " f"After offload: {mem_after_step / 1024**2:.2f} MB, " - f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB") + f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB" + ) optim.turn_on_cpu_offload() torch.cuda.synchronize() mem_after_turn_on = torch.cuda.memory_allocated() if rank == 0: - logger.info("After turn_on_cpu_offload: GPU alloc=%.2f MB", - mem_after_turn_on / 1024**2) + logger.info( + "After turn_on_cpu_offload: GPU alloc=%.2f MB", mem_after_turn_on / 1024**2 + ) assert mem_after_turn_on <= mem_after_step + 4 * 1024 * 1024, ( f"turn_on_cpu_offload() should return memory to offloaded level. " f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " - f"got {mem_after_turn_on / 1024**2:.2f} MB") + f"got {mem_after_turn_on / 1024**2:.2f} MB" + ) for p in params: - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) optim.step() torch.cuda.synchronize() mem_after_next_step = torch.cuda.memory_allocated() if rank == 0: - logger.info("After next step (re-offloaded): GPU alloc=%.2f MB", - mem_after_next_step / 1024**2) + logger.info( + "After next step (re-offloaded): GPU alloc=%.2f MB", + mem_after_next_step / 1024**2, + ) # Allow 4 MB tolerance for CUDA allocator fragmentation. assert mem_after_next_step <= mem_after_step + 4 * 1024 * 1024, ( f"Memory should return to offloaded level after step(). " f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " - f"got {mem_after_next_step / 1024**2:.2f} MB") + f"got {mem_after_next_step / 1024**2:.2f} MB" + ) with pytest.raises( - RuntimeError, - match="turn_off_cpu_offload\\(\\) before checkpoint load"): + RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" + ): optim.load_state_dict(copy.deepcopy(sd_for_load)) optim.turn_off_cpu_offload() @@ -888,24 +937,30 @@ def test_checkpoint_memory(rank, world_size): mem_after_load = torch.cuda.memory_allocated() if rank == 0: - logger.info("After load_state_dict with offload disabled: GPU alloc=%.2f MB", - mem_after_load / 1024**2) + logger.info( + "After load_state_dict with offload disabled: GPU alloc=%.2f MB", + mem_after_load / 1024**2, + ) assert mem_after_load >= mem_after_turn_off, ( - "Loaded optimizer state should stay on GPU while offload is disabled") + "Loaded optimizer state should stay on GPU while offload is disabled" + ) optim.turn_on_cpu_offload() torch.cuda.synchronize() pool = optim._cpu_offload_pool - assert pool._initialized, "Offload pool should be initialized after re-enabling offload" + assert pool._initialized, ( + "Offload pool should be initialized after re-enabling offload" + ) for grp in pool._groups.values(): assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned" # Step 5: verify the loaded optimizer can still step correctly. for p in params: - p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), - mesh, [Shard(0)]) + p.grad = distribute_tensor( + torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] + ) optim.step() torch.cuda.synchronize() @@ -913,7 +968,8 @@ def test_checkpoint_memory(rank, world_size): assert mem_final <= mem_after_step + 4 * 1024 * 1024, ( f"Final memory should be at offloaded level. " f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " - f"got {mem_final / 1024**2:.2f} MB") + f"got {mem_final / 1024**2:.2f} MB" + ) set_ns_compile(True) if rank == 0: diff --git a/torch-ext/optimizer/cpu_offload.py b/torch-ext/optimizer/cpu_offload.py index 5ffa230a95db4749f1b10a400c60d36c1bd33368..fb5e69154a1d4a6c884491413a37a9acf0f66c80 100644 --- a/torch-ext/optimizer/cpu_offload.py +++ b/torch-ext/optimizer/cpu_offload.py @@ -93,10 +93,7 @@ class CPUOffloadPool: indices.append(idx) offsets.append((off, n)) off += n - cpu_flat = torch.empty(off, - dtype=dtype, - device="cpu", - pin_memory=True) + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) self._groups[dtype] = { "indices": indices, "offsets": offsets, @@ -140,8 +137,7 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - cpu_flat[off:off + n].copy_(local.reshape(-1), - non_blocking=True) + cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True) offloaded_bytes += grp["total"] * cpu_flat.element_size() @@ -159,8 +155,10 @@ class CPUOffloadPool: ) if not self._logged: - logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", - offloaded_bytes / (1024**2)) + logger.info( + "[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2), + ) # ------------------------------------------------------------------ def reload(self): @@ -198,12 +196,11 @@ class CPUOffloadPool: for i, mgd_idx in enumerate(indices): local = self._local(self._managed[mgd_idx]) off, n = offsets[i] - local.reshape(-1).copy_(cpu_flat[off:off + n], - non_blocking=True) + local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True) reloaded_bytes += grp["total"] * cpu_flat.element_size() if not self._logged: - logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", - reloaded_bytes / (1024**2)) - self._logged = True + logger.info( + "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2) + ) diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index af16b49d09c56a3c44ea7498ae5b1596494d9746..14c0e22471fa6d47a51ed95e0e0c341dc18d5194 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) def distributed_muon( self, @@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer): scales_full = compute_scales( p, qk_clip_state) if qk_clip_state is not None else None if scales_full is not None: - qk_clip(p, scales_full, qk_clip_state.head_dim) + qk_clip(p, scales_full, qk_clip_state) if not dtensor_params: return @@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer): def state_dict(self) -> dict: if self.cpu_offload: - raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.") + raise RuntimeError( + "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save." + ) return super().state_dict() def load_state_dict(self, state_dict: dict) -> None: if self.cpu_offload: - raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.") + raise RuntimeError( + "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load." + ) super().load_state_dict(state_dict) # Invalidate adamw.py's module-level tensor caches so that diff --git a/torch-ext/optimizer/newton_schulz.py b/torch-ext/optimizer/newton_schulz.py index 2b1a938d06acf1a40985bda013a9061a8d42e407..d939264b69a34e7a3fa78859f34dc265a1159d59 100644 --- a/torch-ext/optimizer/newton_schulz.py +++ b/torch-ext/optimizer/newton_schulz.py @@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000): E = inf for _ in range(max_iter): old_E = E - LHS = np.array([ - [l, l**3, l**5, 1], - [q, q**3, q**5, -1], - [r, r**3, r**5, 1], - [u, u**3, u**5, -1], - ]) + LHS = np.array( + [ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ] + ) a, b, c, E = np.linalg.solve(LHS, np.ones(4)) if not np.all(np.isfinite([a, b, c, E])): - raise ValueError(f"_optimal_quintic: non-finite solve result " - f"a={a}, b={b}, c={c}, E={E}") + raise ValueError( + f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" + ) q, r = np.sqrt( - (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / - (10 * c)) + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) + ) if not np.all(np.isfinite([q, r])): - raise ValueError( - f"_optimal_quintic: non-finite node update q={q}, r={r}") + raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") if abs(old_E - E) <= 1e-15: break else: raise RuntimeError( - f"_optimal_quintic: did not converge after {max_iter} iterations") + f"_optimal_quintic: did not converge after {max_iter} iterations" + ) return float(a), float(b), float(c) @@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): # - Polar Express: analytically optimal per step, adapting to the shrinking # singular-value interval [l, u] as iterations progress; converges all # singular values to 1, producing the exact polar factor UV^T. -_coeffs_list = _optimal_composition(l=1e-3, - num_iters=10, - safety_factor_eps=1e-2, - cushion=0.02) +_coeffs_list = _optimal_composition( + l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02 +) # This code is adapted from: @@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps): X = X / (X.norm() + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations @@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps): X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) hs = _coeffs_list[:steps] + list( - repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + repeat(_coeffs_list[-1], steps - len(_coeffs_list)) + ) for a, b, c in hs: buf1 = torch.bmm(X, X.transpose(-2, -1)) buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) diff --git a/torch-ext/optimizer/qk_clip.py b/torch-ext/optimizer/qk_clip.py index 9bd14b01bb8fa00e246ee34d2483616b4f3230ed..2aba711b3004b7f09e7141da7ef834bd61cc2430 100644 --- a/torch-ext/optimizer/qk_clip.py +++ b/torch-ext/optimizer/qk_clip.py @@ -13,7 +13,11 @@ logger = logging.getLogger(__name__) def parse_qk_layer(name: str) -> tuple[str | None, int]: """ Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + and return (kind, layer_index). + + Supported kinds: + MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj' + MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj) Returns: (kind, layer_idx) or (None, -1) if not matched. @@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.5.attn.wk.weight' -> ('wk', 5) 'model.2.attn.q_proj.weight' -> ('q_proj', 2) 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.1.attn.wq_b.weight' -> ('wq_b', 1) + 'model.0.attn.wkv_b.weight' -> ('wkv_b', 0) 'model.4.attn.v_proj.weight' -> (None, -1) """ parts = normalize_fqn(name).split('.') @@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: layer_idx = int(part) break - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 @@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: @dataclass class QKClipInfo: """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None indices: list[int] # which heads to consider for clipping - head_dim: int # from config + head_dim: int # from config (qk_head_dim for MLA wq_b) threshold: float # from config logit: torch.Tensor | None + # MLA-specific fields + is_mla: bool = False + qk_nope_head_dim: int = 0 + qk_rope_head_dim: int = 0 + v_head_dim: int = 0 + def get_qk_clip_info(clip_config, n, qk_logits): """Extract QK clipping info for a named parameter. Args: clip_config: QK clipping configuration dict (or None). + MHA/GQA keys: head_dim, threshold, q_indices, k_indices + MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim n: Parameter name string. qk_logits: Dict mapping layer indices to logit tensors (or None). @@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits): head_dim = clip_config.get('head_dim') threshold = clip_config.get('threshold') kind, layer_idx = parse_qk_layer(n) + is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = clip_config.get(indices_key, []) or [] - if isinstance(logit, DTensor): # In TP settings, qk_logits may be DTensor # We convert it to full tensor here for simplicity logit = logit.full_tensor() - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) + if kind in ('wq_b', 'wq', 'q_proj'): + indices = clip_config.get('q_indices', []) or [] + elif kind in ('wkv_b', 'wk', 'k_proj'): + indices = clip_config.get('k_indices', []) or [] + + if is_mla: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + is_mla=True, + qk_nope_head_dim=clip_config['qk_nope_head_dim'], + qk_rope_head_dim=clip_config['qk_rope_head_dim'], + v_head_dim=clip_config['v_head_dim'], + ) + else: + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) def compute_scales(p, qk_clip_state): """Compute per-head scaling factors for QK clipping. - Returns scales tensor if any head exceeds threshold, else None. + Returns scales tensor (√γ per head) if any head exceeds threshold, else None. + For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim. """ kind = qk_clip_state.kind indices = qk_clip_state.indices @@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state): if not head_scales: return None - H_global = p.shape[0] // head_dim + # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows + if qk_clip_state.is_mla and kind == 'wkv_b': + effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim + else: + effective_head_dim = head_dim + + H_global = p.shape[0] // effective_head_dim scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale return scales_full -def qk_clip(p, scales, head_dim): - """Apply per-head scaling to a Q/K projection weight matrix.""" - if isinstance(p, torch.nn.Parameter): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - else: - W = p.view(-1, head_dim, p.shape[1]) - W.mul_(scales.view(-1, 1, 1)) +def qk_clip(p, scales, info): + """Apply per-head scaling to a Q/K projection weight matrix. + + Args: + p: Parameter (nn.Parameter or raw tensor). + scales: [n_heads] tensor, each element = √γ_h. + info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions. + + MLA sub-region scaling per Algorithm 1 (MuonClip): + wq_b: q_nope rows → √γ, q_pe rows → γ + wkv_b: k_nope rows → √γ, v rows → unchanged + """ + W = p.data if isinstance(p, torch.nn.Parameter) else p + + if not info.is_mla: + # MHA/GQA: uniform √γ applied to all rows in each head + W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) + return + + # MLA: vectorized sub-region scaling within each head + if info.kind == 'wq_b': + qk_nope = info.qk_nope_head_dim + qk_head_dim = qk_nope + info.qk_rope_head_dim + W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ + W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, + 1)) # q_pe → γ + + elif info.kind == 'wkv_b': + qk_nope = info.qk_nope_head_dim + kv_stride = qk_nope + info.v_head_dim + W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim] + W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ + # v rows: not touched (k_R shared rotary unchanged)