diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..41b9d48cd73b8286ecfcb69084bb4ab50c9cefc2 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90ace47a61519aefe759810c803789e7f91e6949ca0b04fc177e311709976334 +size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 96a6868d0ec423b37d2097f2a60061a3b90efc70..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f095be87ff6185010a3cff4175abbde0b2e50fe1e435dc1db4eaf5bf1f6199ca -size 1940944 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/adamw.py b/build/torch210-cxx11-cu126-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- 
a/build/torch210-cxx11-cu126-x86_64-linux/adamw.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. +def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. 
+# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, 
group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. + _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + 
lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. """ params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/core.py b/build/torch210-cxx11-cu126-x86_64-linux/core.py index 
8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/core.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch210-cxx11-cu126-x86_64-linux/muon.py b/build/torch210-cxx11-cu126-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e1c43b7d051c67005c5fa125813c7c5c004a8702 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1abfa69cd254e0000246a074c0bfa53c2e72bb53cc5fa8216275295cd021c57a +size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 1d1806041a1fdcea027e6aa31eb8b774c6c797d0..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4919c48c77c6223dbf668f1461bcec175ef1bd6ea4cec8c2509de12ca7200a62 -size 2004144 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/adamw.py b/build/torch210-cxx11-cu128-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/adamw.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/core.py b/build/torch210-cxx11-cu128-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/core.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch210-cxx11-cu128-x86_64-linux/muon.py b/build/torch210-cxx11-cu128-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7d604555355a392e1b7562d230d7170798eb9de4 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6869cfabdf45c7092d251846b3099287f8bccd5c5ebe7edf1a5fd21436324349 +size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 08caf42e7e7b1f311490df8058ed06d87ea79358..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b9c7bb12bc030d4959e880a959b39ea07eb03e16175d7cf03829f9860f52525d -size 2004728 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/adamw.py b/build/torch210-cxx11-cu130-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/adamw.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/core.py b/build/torch210-cxx11-cu130-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/core.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch210-cxx11-cu130-x86_64-linux/muon.py b/build/torch210-cxx11-cu130-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/muon.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py b/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..38dec6245ac97ee9f79d91398d5c02fe135c3520 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0102e10121a43f6d5d59a23f2c0e21d88469cc4597d84f7d48b64b0fabfeacdb +size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 49889967591405cc5266af4e0911e0895d7b309b..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:00e9d9e1c2306badb97c3b8f2454a47d6335a302101a38c804ad3c7b075168cc -size 1866400 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py b/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/core.py b/build/torch210-cxx11-rocm70-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/core.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile 
wraps modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py b/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..14df6e77a5767da49bc8be17e6da245a59d901ae --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f41709878a4def27b12f4f9a4f5b767027fb33141e775f64ad04d434fcbe33d9 +size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 10d8f0e7de3adaf54aa7478421c25a02e409544e..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e67022789ddd9296552fc5ab4075ce96b8b00b75bce057c707e5b5076bbde734 -size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py b/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/core.py b/build/torch210-cxx11-rocm71-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/core.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile 
wraps modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py b/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py b/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0a153517cd068f531e1151521831828506324300 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08e3ee2f567d7a89ba34a82429c2f47cdb17d53ef66dc7d5751cabeafa01ce0f +size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 1ccf0dbda4220efff722d4b971b23b40592c3a81..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ee0ac60d2f40d1feb67e804e6b1024844d8cbbf5c62d6d014621a40dc6b3afc3 -size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/adamw.py b/build/torch28-cxx11-cu126-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/adamw.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/core.py b/build/torch28-cxx11-cu126-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/core.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch28-cxx11-cu126-x86_64-linux/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..79e2f5028c3b6f741b5ed831abe6600ef624d197 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e08baad750646c67f23c6e7c4d0e1b7266eeffed3bbb730729ba8f37e120a2b1 +size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 2a7b540994e8d72dfccead970e2fe685f976d2ae..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:89137de30694bc0ad3165d1a998c801151370290ed1217f343409b11a8f74908 -size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/adamw.py b/build/torch28-cxx11-cu128-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/adamw.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/core.py b/build/torch28-cxx11-cu128-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/core.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch28-cxx11-cu128-x86_64-linux/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c4c5dcd3fbdd8d04417b674b11a0b04a80fa892b --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c485caa290f4b43e49db4ceafe25f0d0840dcdd61d02a5aecfa78d8f9acc9b60 +size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 9e281900c03ffb5f3513aa19cc4f0f48e8a90cae..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0f5d04e35a6d7a64d44ba42590c3ef930535c6100498d9c4bc28deb50c819a8d -size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/adamw.py b/build/torch28-cxx11-cu129-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/adamw.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/core.py b/build/torch28-cxx11-cu129-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/core.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch28-cxx11-cu129-x86_64-linux/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/muon.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py b/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8175f4e66cd33f1ee174ff000b00179957cb61b1 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a57395ef49976af61778f127cfdeace6a4c35b491b9903e48b1cd7199ee217c +size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 885ac14b4c5469770fdeaf3766d4c28aa25ada8a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a3fcf69ab6e1e6d7732b6b887350af98666ada6909773898d6b2c8efa56c4cd0 -size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py b/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/core.py b/build/torch28-cxx11-rocm63-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/core.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py b/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a2425d7f317cc206851e06bdeb2dd68df4828eb1 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd57f2197a2107ad920abbce3e2c986b79c76cb864f693f53bd389b26b763902 +size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 6ec327ad391829e41a0a5dc05568e90ac77781b0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dc94ac631623c7169f42b8c21066b4cf03ef892078269fe0c4318634b9c08912 -size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py b/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/core.py b/build/torch28-cxx11-rocm64-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/core.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py b/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py b/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fbee4a83abf477b0b81cb9d30812e4d9795bb5f1 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6540559a09b80b976f3ece411907bf0ef78ebf48c490f4b3372c01a05a87bde4 +size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 44ca420ee062544acac81ece75a66953807a4502..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fb807f26eac961830776950d2bad9ef96838705fcdf5be8c5ee6dc9c18e0c3a4 -size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/adamw.py b/build/torch29-cxx11-cu126-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/adamw.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/core.py b/build/torch29-cxx11-cu126-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/core.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch29-cxx11-cu126-x86_64-linux/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..258bb22535eae6a86ddb0228a24fd6199faa8c18 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca76b214cb1ffdc5d25fe0d76f36b302c5127726d3e04f4e7f0cf01920192250 +size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 81304f1e72844f803f036498f2b7bad16a5d60c1..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5729d5d70fb41aa7eb7ae7fa095c6f6765a0119ac70d0c3139fc31357f4abe78 -size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/adamw.py b/build/torch29-cxx11-cu128-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/adamw.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/core.py b/build/torch29-cxx11-cu128-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/core.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch29-cxx11-cu128-x86_64-linux/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d8cfcfaba28fb17d76d8d9fc1f63aba601aa0659 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cab0b30edb25c98cb71873856093570322b4d10ac2f25b9949c1e615c3ae709 +size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index cad267f9451b926dc53837595c5ec843476dc560..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:09573102bbde35675944ee02dacd2bbad50fc6f151816a6814ef5651adf40e69 -size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/adamw.py b/build/torch29-cxx11-cu130-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/adamw.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/core.py b/build/torch29-cxx11-cu130-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/core.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( 
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch29-cxx11-cu130-x86_64-linux/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/muon.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, shape=%s, " + 
"placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, weight_decay) 
@@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py b/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function -from 
.core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = qk_clip_state.threshold 
logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5151a4ccef3f6f86a1f0959fc1aa14c9783a89c4 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bebe62763a8ae7134def4c928029cbe350113fcdc4e265d4a78b11e7d8d02bef +size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index 75edac5d5f08066ba4f74df9fb2c6b740d65e613..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:40563a27767823176595fede23009b17b26e6b2c6a5847e255448d51da70b854 -size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py b/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/core.py b/build/torch29-cxx11-rocm63-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/core.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py b/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py index b34ab4955d83942fd070363fe79547a36deb1742..4a298dcaadca852ceae58fff62adbebb27c99394 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _optimizer_7aef62f_dirty -ops = torch.ops._optimizer_7aef62f_dirty +from . import _optimizer_5b58933_dirty +ops = torch.ops._optimizer_5b58933_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_optimizer_7aef62f_dirty::{op_name}" \ No newline at end of file + return f"_optimizer_5b58933_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5d9478390830c11d4c370c56732e51785ec9a5d2 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_5b58933_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11694657e111a565af0c1229bfcdbcf9ac47246f5da63270897d0f88f4ef83da +size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so deleted file mode 100755 index db1924b3f25a792a5aa5de6db2005cac974da79a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_7aef62f_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ae0556a81551f05fff0b83b1924c55a70e399c29171f9f7ce1bd63ccb24fc417 -size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py b/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py index a6125200cc3da0996f0f3344131a7c6de4ac5863..b5a95816a9f5b9e1889eaadae65373bfbced809a 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/adamw.py @@ -1,8 +1,12 @@ +import logging from collections import defaultdict from typing import cast import torch from torch.distributed.tensor import DTensor +from torch.profiler import record_function + +logger = logging.getLogger(__name__) def fused_adamw( @@ -72,54 +76,72 @@ def fused_adamw( ) -def step_adamw_params(optimizer_state, params, group): - """Run fused AdamW on a list of parameters sharing the same placement. 
+def _to_local(t): + """Unwrap DTensor to local tensor for fused ops.""" + return t._local_tensor if isinstance(t, DTensor) else t - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - params: List of parameters to update. - group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. - """ + +# --------------------------------------------------------------------------- +# Caches for eliminating per-step Python overhead. +# +# Placement grouping and tensor list assembly are identical every step +# (params don't change placement, moment/step tensors are the same objects +# after initialisation). We cache them keyed by id() of the param list +# stored in param_groups (stable across steps). +# +# Only gradients change each step and must be collected fresh. +# --------------------------------------------------------------------------- + +# id(group["params"]) → dict[placement_key, list[param]] +_placement_cache: dict[int, dict[tuple, list]] = {} + +# id(placement_group_list) → (params_local, moment1, moment2, state_steps) +_tensor_cache: dict[int, tuple[list, list, list, list]] = {} + + +def _step_adamw_params_slow(optimizer_state, params, group): + """Uncached fallback for the rare case where some params lack grads.""" params_with_grads = [] grads = [] moment1 = [] moment2 = [] - max_exp_avg_sqs = [] state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] for p in params: g = p.grad if g is None: continue state = optimizer_state[p] - params_with_grads.append(p) - grads.append(g) + params_with_grads.append(_to_local(p)) + grads.append(_to_local(g)) if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) state["moment1"] = torch.zeros_like(g) state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - 
moment2.append(state["moment2"]) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + if not params_with_grads: + return + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] fused_adamw( params_with_grads, grads, moment1, moment2, - max_exp_avg_sqs, + [], state_steps, amsgrad=False, beta1=beta1, @@ -131,24 +153,119 @@ def step_adamw_params(optimizer_state, params, group): ) +def step_adamw_params(optimizer_state, params, group): + """Run fused AdamW on a list of parameters sharing the same placement. + + After the first call, cached tensor lists (params_local, moment1, + moment2, state_steps) are reused — only gradients are collected fresh. + + Args: + optimizer_state: The optimizer's state dict (self.state in Muon). + params: List of parameters to update. + group: Parameter group dict with lr, adamw_betas, adamw_eps, weight_decay. + """ + # Collect grads — the only thing that changes each step. + with record_function("adamw::collect_grads"): + grads = [] + for p in params: + g = p.grad + if g is None: + # Rare: fall back to slow path that filters per-param. 
+ _step_adamw_params_slow(optimizer_state, params, group) + return + grads.append(_to_local(g)) + + tensor_key = id(params) + if tensor_key not in _tensor_cache: + with record_function("adamw::init_tensor_cache"): + params_local = [] + moment1 = [] + moment2 = [] + state_steps = [] + + for p in params: + state = optimizer_state[p] + params_local.append(_to_local(p)) + if "step" not in state: + state["step"] = torch.zeros((), + dtype=torch.float32, + device=p.device) + state["moment1"] = torch.zeros_like(p.grad) + state["moment2"] = torch.zeros_like(p.grad) + moment1.append(_to_local(state["moment1"])) + moment2.append(_to_local(state["moment2"])) + if not isinstance(state["step"], torch.Tensor): + state["step"] = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + state_steps.append(state["step"]) + + _tensor_cache[tensor_key] = (params_local, moment1, moment2, + state_steps) + + params_local, moment1, moment2, state_steps = _tensor_cache[tensor_key] + + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + with record_function("adamw::fused_adamw"): + fused_adamw( + params_local, + grads, + moment1, + moment2, + [], + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def step_adamw(optimizer_state, group): """Dispatch AdamW step, grouping parameters by type and placement. + Placement grouping is cached after the first call since params never + change their placement between steps. + Args: optimizer_state: The optimizer's state dict (self.state in Muon). group: Parameter group dict. 
""" params = group["params"] + placement_key = id(params) - # group params with its type and placement - placement_to_params: dict[tuple, list[torch.Tensor]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for group_params in placement_to_params.values(): + if placement_key not in _placement_cache: + with record_function("adamw::group_by_placement"): + placement_to_params: dict[tuple, + list[torch.Tensor]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + logger.debug( + "[AdamW] DTensor param: shape=%s, placements=%s, " + "mesh=%s, grad=%s", p.shape, p.placements, + p.device_mesh.mesh_dim_names, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple( + [p.placements, p.device_mesh])].append(p) + case torch.Tensor(): + logger.debug( + "[AdamW] plain param: shape=%s, grad=%s", p.shape, + p.grad.shape if p.grad is not None else None) + placement_to_params[tuple([torch.Tensor, + None])].append(p) + + logger.debug("[AdamW] %d placement groups, %d total params", + len(placement_to_params), len(params)) + + _placement_cache[placement_key] = dict(placement_to_params) + + for group_params in _placement_cache[placement_key].values(): step_adamw_params(optimizer_state, group_params, group) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/core.py b/build/torch29-cxx11-rocm64-x86_64-linux/core.py index 8d47ffe8e68b523ff226f7bb1ca2bc059ee3b409..c69d515afef305ad0ed66374095fa2d2468d99cc 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/core.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/core.py @@ -1,11 +1,25 @@ +import logging import math from dataclasses import dataclass +from typing import List import torch -import torch.distributed as dist from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor +# torch.compile wraps 
modules as OptimizedModule, inserting "_orig_mod" into +# parameter FQNs. Activation checkpointing similarly inserts +# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, +# expert_keys, QK layer parsing) works regardless of wrapper nesting. +_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) + +logger = logging.getLogger(__name__) + + +def normalize_fqn(name: str) -> str: + """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" + return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) + @dataclass class _muon_state: @@ -17,26 +31,71 @@ class _muon_state: qk_clip_state: torch.Tensor | None = None -def update_g(optimizer_state, p, g, group, momentum): - """Apply momentum update to gradient. +def _batch_momentum( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update (no nesterov).""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) - Args: - optimizer_state: The optimizer's state dict (self.state in Muon). - p: Parameter tensor. - g: Gradient tensor. - group: Parameter group dict. - momentum: Momentum coefficient. - Returns: - Momentum-updated gradient tensor. 
+def _batch_momentum_nesterov( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, +) -> None: + """Batched momentum update with nesterov correction.""" + torch._foreach_mul_(momentum_bufs, momentum) + torch._foreach_add_(momentum_bufs, grads) + nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) + torch._foreach_add_(grads, nesterov_terms) + + +_compiled_momentum: dict[bool, callable] = {} +_use_momentum_compile = True + + +def set_momentum_compile(enabled: bool): + """Toggle torch.compile for batched momentum.""" + global _use_momentum_compile + _use_momentum_compile = enabled + + +def batch_pre_ortho( + grads: List[torch.Tensor], + momentum_bufs: List[torch.Tensor], + momentum: torch.Tensor, + nesterov: bool, +) -> None: + """Batched momentum update on lists of plain tensors. + + Mirrors dion's ``muon_update_pre_orthogonalize``. + Inputs must be plain CUDA tensors (not DTensor). + Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. + + When compile is enabled, uses separately compiled functions for + nesterov=True/False to avoid graph breaks from the branch. """ - state = optimizer_state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf + fn = _batch_momentum_nesterov if nesterov else _batch_momentum + if _use_momentum_compile: + if nesterov not in _compiled_momentum: + _compiled_momentum[nesterov] = torch.compile(fn) + fn = _compiled_momentum[nesterov] + fn(grads, momentum_bufs, momentum) + + +def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): + """Weight-decay + update on plain tensors. + + Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache + lookup per call × 256+ params = massive overhead. The pipeline path uses + batched _foreach_* ops instead; this function remains for base() and + distributed_muon(). 
+ """ + p_data.mul_(1 - lr * weight_decay) + p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): @@ -49,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay): adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ - if isinstance(p, torch.nn.Parameter): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - else: - p.mul_(1 - lr * weight_decay) - p.add_(u, alpha=-adjusted_lr) + # Unwrap Parameter -> underlying data tensor. + p_data = p.data if isinstance(p, torch.nn.Parameter) else p + # Unwrap DTensor -> local CUDA tensor for compiled kernel. + if isinstance(p_data, DTensor): + p_data = p_data._local_tensor + u_data = u._local_tensor if isinstance(u, DTensor) else u + _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): @@ -77,14 +135,55 @@ def adjust_lr_for_muon(lr, param_shape): return adjusted_lr +def _match_key(parts, key): + """Check if key matches as contiguous components in parts. + + Single-component keys (e.g. "experts") match any single component. + Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. 
+ """ + key_parts = key.split(".") + key_len = len(key_parts) + if key_len == 1: + return key in parts + return any(parts[i:i + key_len] == key_parts + for i in range(len(parts) - key_len + 1)) + + +def is_expert_param(name, expert_keys): + """Check if a parameter name matches any expert key (component-level).""" + if not expert_keys: + return False + parts = normalize_fqn(name).split(".") + return any(_match_key(parts, key) for key in expert_keys) + + def default_is_muon(name, x, expert_keys=None): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - if any(key in name for key in skip_keys): + normalized = normalize_fqn(name) + parts = normalized.split(".") + skip_keys = [ + "embed_tokens", + "lm_head", + "tok_embeddings", + "output", + "mhc_attn", + "mhc_ffn", + "lambda_proj", + ] + if any(key in parts for key in skip_keys): + logger.info( + "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", + normalized, name, x.ndim) return False effective_ndim = x.ndim - if expert_keys and any(key in name for key in expert_keys): + is_expert = is_expert_param(name, expert_keys) + if is_expert: effective_ndim -= 1 - return effective_ndim >= 2 + result = effective_ndim >= 2 + logger.info( + "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", + normalized, name, x.ndim, is_expert, effective_ndim, + "Muon" if result else "AdamW") + return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): @@ -92,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] - non_muon_params = [] + non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: @@ -102,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): muon_names.append(n) else: non_muon_params.append(p) + non_muon_names.append(n) + + 
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", + expert_keys, len(muon_names), len(non_muon_names)) return [ { diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py b/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..58840a02b3f589f7922e2779241d13a82494da8c --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/cpu_offload.py @@ -0,0 +1,188 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. 
+ # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream( + self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_( + local.reshape(-1), non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_( + cpu_flat[off:off + n], non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py index 75e2e1e8d66975fc9aea75d994de288216a5e9a4..890ebab62fa07474c71bfae393e3b168a1c69d7d 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -72,12 +72,6 @@ def get_slices_of_dtensor( else: curr_size = target.size()[shard_dim] - if curr_size % num_chunks != 0: - raise NotImplementedError( - f"Dimension size 
{curr_size} is not divisible " - f"by number of ranks {num_chunks} for shard " - f"placement on dim {shard_dim}. (shape: {target.shape})") - # Compute indices for this level of sharding if isinstance(placement, _StridedShard): _shard_size, offsets = _StridedShard.local_shard_size_and_offset( diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py b/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py index 95414c6dcd6ec6cd52bf7aebafa260871aff27aa..792de23d82c3fb45fe33d397ab9b76a0787259d0 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py @@ -43,6 +43,7 @@ def get_autotune_config(): @triton.autotune( configs=get_autotune_config(), key=['M', 'K'], + restore_value=['y'], ) @triton.jit def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, @@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) -def matmul_transpose_assign(d_in, d_out): - assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" - assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" - assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" - assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" - assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" - assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" - assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ - "First dimension of `d_in` must match first and second dimension of `d_out`" - +@torch.library.custom_op("muon::matmul_transpose_assign", + mutates_args=("d_out", )) +def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """Compute d_out = d_in @ d_in.T using an optimized Triton kernel.""" d_in = d_in.contiguous() M, K = d_in.shape grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * 
triton.cdiv( @@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out): with torch.cuda.device(d_in.device.index): mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), d_out.stride(0), d_out.stride(1)) + + +@matmul_transpose_assign.register_fake +def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None: + """FakeTensor impl: d_out is already allocated, mutation is declared.""" + pass diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py index 1195ca7bf4c2b594b5459ec114b8a8f2e530ad66..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/muon.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/muon.py @@ -10,13 +10,16 @@ from torch.profiler import record_function from .adamw import step_adamw from .async_utils import run_pipeline -from .core import (_muon_state, adjust_lr_for_muon, - get_default_muon_param_groups, update_g, update_p) +from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, + get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, construct_shard_mesh, get_slices_of_dtensor) from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, - _zeropower_via_newtonschulz5) -from .pipeline import muon_chunk_pipeline + _zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5, + zeropower_via_newtonschulz5_batched) +from .pipeline import muon_chunk_pipeline, prelaunch_first_gather from .qk_clip import compute_scales, get_qk_clip_info, qk_clip logger = logging.getLogger(__name__) @@ -45,9 +48,21 @@ def _expand_expert_params(names, params, expert_keys): expanded_params = [] for n, p in zip(names, params): - is_expert = expert_keys and any(key in n for key in expert_keys) + is_expert = is_expert_param(n, expert_keys) is_dtensor = isinstance(p.data, DTensor) + if is_expert: + if is_dtensor: + logger.debug( + "[expand_expert] %s: expert DTensor, 
shape=%s, " + "placements=%s, mesh=%s, local_shape=%s", n, p.shape, + p.placements, p.device_mesh.mesh_dim_names, + p.to_local().shape) + else: + logger.debug( + "[expand_expert] %s: expert plain tensor, shape=%s", n, + p.data.shape) + if not is_expert: assert p.data.ndim <= 2, ( f"Param {n} has ndim={p.data.ndim} but does not match " @@ -168,7 +183,6 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. - small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon expert_keys: List of strings to identify expert-parallel parameters. If any key appears in a parameter's name, its outermost dimension is treated as the expert dimension and expanded @@ -193,8 +207,8 @@ class Muon(torch.optim.Optimizer): warmup_step=5, chunk_size=-1, use_distributed_muon=False, - small_param_numel_threshold=65536, - expert_keys=None): + expert_keys=None, + cpu_offload=False): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -228,8 +242,12 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon - self.small_param_numel_threshold = small_param_numel_threshold self.expert_keys = expert_keys + self.cpu_offload = cpu_offload + self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None + self._offload_initialized = False + self._parallel_cache: dict[tuple[str, ...], dict] = {} + self._expert_expand_cache: dict[tuple[int, ...], dict] = {} def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -333,8 +351,8 @@ class Muon(torch.optim.Optimizer): if g is None: continue - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) adjusted_lr = adjust_lr_for_muon(lr, p.shape) update_p(p, u, lr, adjusted_lr, 
weight_decay) @@ -355,52 +373,269 @@ class Muon(torch.optim.Optimizer): weight_decay: float, qk_logits: list[torch.Tensor | DTensor] | None, ): - """ Implementation of Distributed Muon by Liu et al. """ + """Batched Distributed Muon — for testing/correctness verification only. - # Momentum is already applied by _step_muon before this method. - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - - # Gather G - if isinstance(p.data, DTensor): - g_full = g.full_tensor() - p_full = p.data.full_tensor() - else: - g_full = g - p_full = p - - u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = adjust_lr_for_muon(lr, p_full.shape) - update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + Uses all-gather to reconstruct full tensors, computes Newton-Schulz on + the full grad, then slices back to local shards. This is simpler but + slower than the parallel pipeline (all2all) path, so it serves as a + reference implementation for verifying correctness. + """ + with record_function("distributed_muon"): + # Momentum is already applied by _step_muon before this method. + ns_steps = group["ns_steps"] - qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits) + # Separate plain tensors (no communication) from DTensors. + plain_names, plain_params = [], [] + dtensor_names, dtensor_params = [], [] + for n, p in zip(names, params): + if p.grad is None: + continue + if isinstance(p.data, DTensor): + dtensor_names.append(n) + dtensor_params.append(p) + else: + plain_names.append(n) + plain_params.append(p) + + # Process plain tensors per-param (no communication). 
+ for n, p in zip(plain_names, plain_params): + u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE), + steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = get_qk_clip_info(self.clip_config, n, + qk_logits) + scales_full = compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + qk_clip(p, scales_full, qk_clip_state.head_dim) + + if not dtensor_params: + return + + # Group DTensors by (placements, mesh) for batched all-gather. + placement_groups: dict[tuple, + tuple[list, + list]] = defaultdict(lambda: ([], [])) + for n, p in zip(dtensor_names, dtensor_params): + key = (p.placements, p.device_mesh) + placement_groups[key][0].append(n) + placement_groups[key][1].append(p) + + logger.info( + "distributed_muon: %d placement groups, %d total dtensors", + len(placement_groups), len(dtensor_params)) + + for (placements, mesh), (grp_names, + grp_params) in placement_groups.items(): + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + placements, mesh) + rank = dist.get_rank(shard_pg) + world_size = dist.get_world_size(shard_pg) + + logger.info(" group: %d params, placements=%s, world_size=%d", + len(grp_params), placements, world_size) + + # Separate params that can be batched (all shard dims evenly + # divisible) from those needing per-param full_tensor + # (e.g. MoE gate weights with fewer rows than shard ranks). + # all_gather_into_tensor requires equal buffer sizes across + # ranks, so uneven splits must use DTensor full_tensor(). 
+ batch_names, batch_params = [], [] + single_names, single_params = [], [] + for n, p in zip(grp_names, grp_params): + even = all(p.shape[pl.dim] % + shard_mesh.mesh.shape[dim_idx] == 0 + for dim_idx, pl in enumerate(shard_placements)) + if even: + batch_names.append(n) + batch_params.append(p) + else: + single_names.append(n) + single_params.append(p) + + # Process uneven-split params per-param via full_tensor(). + for n, p in zip(single_names, single_params): + with record_function("distributed_muon::newton_schulz"): + g_full = p.grad.full_tensor().to(COMM_DTYPE) + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + if not batch_params: + continue - scales_full = compute_scales( - p_full, qk_clip_state) if qk_clip_state is not None else None + logger.info(" batched=%d, single=%d", len(batch_params), + len(single_params)) + + # Concat all local grad shards into a single flat buffer. 
+ with record_function("distributed_muon::gather"): + grad_locals = [ + p.grad.to_local().to(COMM_DTYPE).flatten() + for p in batch_params + ] + numels = [g.numel() for g in grad_locals] + grad_concat = torch.cat(grad_locals) + del grad_locals + + # Single all-gather (replaces N separate full_tensor). + grad_gathered = torch.empty( + grad_concat.numel() * world_size, + dtype=COMM_DTYPE, + device="cuda", + ) + dist.all_gather_into_tensor(grad_gathered, + grad_concat, + group=shard_pg) + + total_numel = grad_concat.numel() + del grad_concat + + # Precompute per-param offsets within the concat buffer. + offsets = [] + off = 0 + for ne in numels: + offsets.append(off) + off += ne + + # Per-param: reconstruct full grad → NS → local update. + for i, (n, p) in enumerate(zip(batch_names, batch_params)): + with record_function("distributed_muon::newton_schulz"): + g_full = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + for r in range(world_size): + r_start = r * total_numel + offsets[i] + shard = grad_gathered[r_start:r_start + numels[i]] + indices = get_slices_of_dtensor( + p, r, shard_mesh, shard_placements) + g_full[indices] = shard.reshape( + g_full[indices].shape) + + u_full = _zeropower_via_newtonschulz5(g_full, + steps=ns_steps) + del g_full + + with record_function("distributed_muon::update"): + adjusted_lr = adjust_lr_for_muon(lr, p.shape) + p._local_tensor.mul_(1 - lr * weight_decay) + local_indices = get_slices_of_dtensor( + p, rank, shard_mesh, shard_placements) + u_local = u_full[local_indices] + p._local_tensor.add_(u_local, alpha=-adjusted_lr) + del u_full + + qk_clip_state = get_qk_clip_info( + self.clip_config, n, qk_logits) + scales_full = compute_scales( + p, qk_clip_state + ) if qk_clip_state is not None else None + if scales_full is not None: + ratio = p.shape[0] // scales_full.shape[0] + idx0 = local_indices[0] + if isinstance(idx0, slice): + start = idx0.start or 0 + idx0 = torch.arange(start, + idx0.stop, + device=scales_full.device) + 
row_scales = scales_full[idx0 // ratio] + p._local_tensor.mul_(row_scales.view(-1, 1)) + + def _setup_parallel(self, names, params, group, qk_logits): + """Compute (or retrieve cached) parallel pipeline metadata. + + Returns: + (ordered_params, param_to_state, rank, chunk_size) + """ + cache_key = tuple(names) - if scales_full is not None: - qk_clip(p_full, scales_full, qk_clip_state.head_dim) + if cache_key not in self._parallel_cache: + # First call: compute metadata and populate cache. + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) - if isinstance(p.data, DTensor): - ndims = len(p.device_mesh.mesh.shape) - p_replicate = DTensor.from_local( - p_full, - device_mesh=p.device_mesh, - placements=[Replicate() for _ in range(ndims)], - ) + shard_pg = param_to_state[id(ordered_params[0])].process_group + rank = dist.get_rank(group=shard_pg) - p_sharded = p_replicate.redistribute( - device_mesh=p.device_mesh, - placements=p.placements, + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(shard_pg) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError( + "chunk_size must be -1 or a positive integer.") + + ordered_names = [ + param_to_state[id(p)].name for p in ordered_params + ] + name_to_state = { + param_to_state[id(p)].name: param_to_state[id(p)] + for p in ordered_params + } + self._parallel_cache[cache_key] = { + 'ordered_names': ordered_names, + 'name_to_state': name_to_state, + 'rank': rank, + 'chunk_size': chunk_size, + } + else: + # Cached path: rebuild param_to_state with current id(p) keys. 
+ cache = self._parallel_cache[cache_key] + rank = cache['rank'] + chunk_size = cache['chunk_size'] + + name_to_param = dict(zip(names, params)) + ordered_params = [name_to_param[n] for n in cache['ordered_names']] + + param_to_state = {} + for p, n in zip(ordered_params, cache['ordered_names']): + cached_state = cache['name_to_state'][n] + param_to_state[id(p)] = _muon_state( + worker_rank=cached_state.worker_rank, + process_group=cached_state.process_group, + rank_indices=cached_state.rank_indices, + rank_numels=cached_state.rank_numels, + name=n, + qk_clip_state=get_qk_clip_info(self.clip_config, n, + qk_logits), ) - p.copy_(p_sharded) + return ordered_params, param_to_state, rank, chunk_size - def parallel(self, names, params, group, lr, weight_decay, qk_logits): + def parallel(self, + names, + params, + group, + lr, + weight_decay, + qk_logits, + prelaunch_gather=None): """ Perform a parallel optimization step using Muon. @@ -409,31 +644,23 @@ class Muon(torch.optim.Optimizer): interleaves multiple chunks so that communication and computation overlap across chunks (the same overlap previously achieved by the warmup + main-loop index scheduling). + + If ``prelaunch_gather`` is provided, it is passed to the first + chunk's generator to skip re-launching the already in-flight + A2A gather. """ # Momentum is already applied by _step_muon before this method. - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - # Compute local rank for this group's shard process group. 
- shard_pg = param_to_state[id(ordered_params[0])].process_group - rank = dist.get_rank(group=shard_pg) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - ordered_params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") + ordered_params, param_to_state, rank, chunk_size = ( + self._setup_parallel(names, params, group, qk_logits)) def pipelines(): + first = True for start in range(0, len(ordered_params), chunk_size): chunk = ordered_params[start:start + chunk_size] if chunk: - yield muon_chunk_pipeline( + kwargs = dict( params=chunk, param_to_state=param_to_state, rank=rank, @@ -442,9 +669,11 @@ class Muon(torch.optim.Optimizer): weight_decay=weight_decay, none_grad=group["none_grad"], ) + if first and prelaunch_gather is not None: + kwargs['prelaunch_gather'] = prelaunch_gather + first = False + yield muon_chunk_pipeline(**kwargs) - with record_function("muon::barrier"): - dist.barrier() with record_function("muon::pipeline"): run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1) @@ -456,16 +685,152 @@ class Muon(torch.optim.Optimizer): names = group["names"] # Apply momentum to all params before routing/expansion. + # Batched using _foreach_* ops (compiled, fullgraph=True). with record_function("muon::momentum"): - for n, p in zip(names, params): - g = p.grad - if g is None: + active_params = [p for p in params if p.grad is not None] + if active_params: + # Ensure momentum buffers exist (avoid zeros_like when already present). + for p in active_params: + if "momentum_buffer" not in self.state[p]: + self.state[p]["momentum_buffer"] = torch.zeros_like( + p.grad) + + # Extract local tensors for compiled batch function. 
+ local_grads = [ + p.grad._local_tensor + if isinstance(p.grad, DTensor) else p.grad + for p in active_params + ] + local_bufs = [ + self.state[p]["momentum_buffer"]._local_tensor + if isinstance(self.state[p]["momentum_buffer"], DTensor) + else self.state[p]["momentum_buffer"] + for p in active_params + ] + + # Wrap momentum as tensor for torch.compile. + batch_pre_ortho(local_grads, local_bufs, + torch.tensor(momentum), group["nesterov"]) + + # For non-nesterov, the result is the momentum buffer. + if not group["nesterov"]: + for p in active_params: + p.grad = self.state[p]["momentum_buffer"] + + # Identify batched experts for deferred NS. + # Detection is cheap (condition checks only); actual NS compute is + # deferred so it can overlap with the first chunk's A2A gather. + deferred_expert_work = [] + if self.expert_keys: + batched_expert_indices = [] + for i, (n, p) in enumerate(zip(names, params)): + if not (is_expert_param(n, self.expert_keys) + and p.grad is not None): continue - g = update_g(self.state, p, g, group, momentum) - p.grad = g + # Eligible: plain tensor, or DTensor with no non-dim-0 shards. + if isinstance(p.data, DTensor): + has_tp = any( + _is_shard(pl) and pl.dim != 0 for pl in p.placements) + if has_tp: + continue + batched_expert_indices.append(i) + + if batched_expert_indices: + # Save refs for deferred NS; free grads from param list. + for i in batched_expert_indices: + p = params[i] + g = p.grad + local_g = (g._local_tensor + if isinstance(g, DTensor) else g) + local_data = (p.data._local_tensor if isinstance( + p.data, DTensor) else p.data) + deferred_expert_work.append((local_data, local_g)) + p.grad = None + + # Remove batched experts from lists before expansion. 
+ keep = sorted( + set(range(len(params))) - set(batched_expert_indices)) + names = [names[i] for i in keep] + params = [params[i] for i in keep] + + def _run_deferred_expert_ns(): + """Execute deferred batched expert NS.""" + if not deferred_expert_work: + return + with record_function("muon::batched_expert_ns"): + ns_steps = group["ns_steps"] + for local_data, local_g in deferred_expert_work: + u = zeropower_via_newtonschulz5_batched( + local_g.to(COMM_DTYPE), steps=ns_steps) + adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:]) + local_data.mul_(1 - lr * weight_decay) + local_data.add_(u, alpha=-adjusted_lr) # Expand expert params by splitting on dim 0. - names, params = _expand_expert_params(names, params, self.expert_keys) + logger.debug("[_step_muon] before expand: %d params, expert_keys=%s", + len(params), self.expert_keys) + if self.expert_keys: + cache_key = tuple(id(p) for p in params) + cache = self._expert_expand_cache.get(cache_key) + + if cache is None: + # Cold path: full expansion + build cache metadata. + exp_names, exp_params = _expand_expert_params( + names, params, self.expert_keys) + + # Build per-expert-group info for hot-path grad updates. + grad_info = [] + exp_idx = 0 + for orig_idx, (n, p) in enumerate(zip(names, params)): + if not is_expert_param(n, self.expert_keys): + exp_idx += 1 + continue + + is_dt = isinstance(p.data, DTensor) + num_experts = (p.to_local() if is_dt else p.data).shape[0] + + # Detect TP mesh from the first expanded expert param. 
+ tp_mesh = None + tp_pls = None + sample = exp_params[exp_idx] + if isinstance(sample.data, DTensor): + tp_mesh = sample.data.device_mesh + tp_pls = list(sample.data.placements) + + grad_info.append((orig_idx, num_experts, exp_idx, is_dt, + tp_mesh, tp_pls)) + exp_idx += num_experts + + self._expert_expand_cache[cache_key] = { + 'names': exp_names, + 'params': exp_params, + 'grad_info': grad_info, + } + names, params = exp_names, exp_params + else: + # Hot path: reuse cached params, only update expert grads. + for (orig_idx, num_experts, exp_start, is_dt, tp_mesh, + tp_pls) in cache['grad_info']: + p = params[orig_idx] + g = p.grad + local_grad = (g.to_local() + if is_dt and isinstance(g, DTensor) else g) + for i in range(num_experts): + expert_p = cache['params'][exp_start + i] + sg = local_grad[i] + if tp_mesh is not None: + expert_p.grad = DTensor.from_local( + sg, device_mesh=tp_mesh, placements=tp_pls) + else: + expert_p.grad = sg + p.grad = None + + names = cache['names'] + params = cache['params'] + else: + names, params = _expand_expert_params(names, params, + self.expert_keys) + logger.debug("[_step_muon] after expand: %d params", len(params)) param_dtensors = [] name_dtensors = [] @@ -473,10 +838,10 @@ class Muon(torch.optim.Optimizer): param_tensors = [] name_tensors = [] - param_dtensors_small = [] - name_dtensors_small = [] - + # distributed_muon is a reference implementation for testing only. + # The parallel pipeline (all2all) path below is the production path. if self.use_distributed_muon: + _run_deferred_expert_ns() self.distributed_muon(names=names, params=params, group=group, @@ -485,8 +850,6 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return - # For simplicity, we use distributed Muon for small parameters - # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -494,23 +857,28 @@ class Muon(torch.optim.Optimizer): if all( isinstance(placement, Replicate) for placement in p.placements): + logger.debug( + "[route] %s → base (DTensor all-Replicate), " + "shape=%s, placements=%s", n, p.shape, p.placements) param_tensors.append(p) name_tensors.append(n) - elif p.data.numel() <= self.small_param_numel_threshold: - param_dtensors_small.append(p) - name_dtensors_small.append(n) else: + logger.debug( + "[route] %s → parallel (DTensor), shape=%s, " + "placements=%s, mesh=%s", n, p.shape, p.placements, + p.device_mesh.mesh_dim_names) param_dtensors.append(p) name_dtensors.append(n) elif isinstance(p.data, torch.Tensor): + logger.debug("[route] %s → base (plain tensor), shape=%s", n, + p.data.shape) param_tensors.append(p) name_tensors.append(n) else: raise TypeError(f"Unsupported parameter type: {type(p.data)}") - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " - f"{len(param_dtensors_small)} Small DTensors") + logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, " + f"{len(param_tensors)} Tensors → base") def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements @@ -526,21 +894,6 @@ class Muon(torch.optim.Optimizer): p.device_mesh])][1].append(p) return placement_to_params - if len(param_dtensors_small) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." 
- ) - - self.distributed_muon( - params=param_dtensors_small, - names=name_dtensors_small, - group=group, - lr=lr, - weight_decay=weight_decay, - qk_logits=qk_logits, - ) - if len(param_dtensors) > 0: if not dist.is_initialized(): raise RuntimeError( @@ -548,7 +901,26 @@ class Muon(torch.optim.Optimizer): ) dtensor_group = group_dtensors(param_dtensors, name_dtensors) + + # Pre-launch the first chunk's A2A gather so that the NCCL + # communication overlaps with the (deferred) batched expert NS + # compute on the default CUDA stream. + prelaunch = None + if deferred_expert_work: + first_names, first_params = next(iter(dtensor_group.values())) + ordered, pts, rnk, csz = self._setup_parallel( + first_names, first_params, group, qk_logits) + first_chunk = ordered[:csz] + if first_chunk: + prelaunch = prelaunch_first_gather(first_chunk, pts, rnk, + group["none_grad"]) + + _run_deferred_expert_ns() + + first_group = True for _, (names, params) in dtensor_group.items(): + pg = prelaunch if first_group else None + first_group = False self.parallel( names, params, @@ -556,7 +928,10 @@ class Muon(torch.optim.Optimizer): lr=lr, weight_decay=weight_decay, qk_logits=qk_logits, + prelaunch_gather=pg, ) + else: + _run_deferred_expert_ns() if len(param_tensors) > 0: self.base( @@ -568,6 +943,33 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits, ) + def _register_states_for_offload(self): + """Register all optimizer state tensors with the CPU offload pool. + + Called once after the first step when states have been lazily created. + Offloads all param states (momentum buffers for Muon, moment1/moment2 + for AdamW) to free GPU memory between steps. 
+ """ + pool = self._cpu_offload_pool + tracked = 0 + for group in self.param_groups: + for p in group["params"]: + if p not in self.state: + continue + state = self.state[p] + if group.get("use_muon", False): + if "momentum_buffer" in state: + pool.track(state["momentum_buffer"]) + tracked += 1 + else: + if "moment1" in state: + pool.track(state["moment1"]) + if "moment2" in state: + pool.track(state["moment2"]) + tracked += 1 + logger.info("[CPUOffload] Registered %d param states for offload", + tracked) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. @@ -585,10 +987,82 @@ class Muon(torch.optim.Optimizer): with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # H2D: reload optimizer states from CPU before computation. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + + logger.debug("[Muon.step] expert_keys=%s, %d param groups", + self.expert_keys, len(self.param_groups)) + + for i, group in enumerate(self.param_groups): if group["use_muon"]: + logger.debug("[Muon.step] group %d: use_muon=True, %d params", + i, len(group["params"])) self._step_muon(group, qk_logits=qk_logits) else: + logger.debug( + "[Muon.step] group %d: use_muon=False (AdamW), %d params", + i, len(group["params"])) step_adamw(self.state, group) + # D2H: offload optimizer states to CPU after computation. + if self.cpu_offload: + if not self._offload_initialized: + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() + return loss + + # ------------------------------------------------------------------ + # Checkpoint support for cpu_offload + # ------------------------------------------------------------------ + + def state_dict(self) -> dict: + """Return optimizer state dict, reloading offloaded states first. + + When ``cpu_offload=True``, optimizer state tensors have their GPU + storage freed (``resize_(0)``) between steps. 
We reload them, + snapshot the state dict, then re-offload so the optimizer stays + in the expected post-step state. The returned dict holds cloned + tensors so they remain valid after the re-offload frees the + originals' GPU storage. + """ + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + sd = super().state_dict() + if self.cpu_offload and self._offload_initialized: + # Clone state tensors so the returned dict survives re-offload + # (which frees GPU storage on the originals via resize_(0)). + for k in sd["state"]: + sd["state"][k] = { + sk: sv.clone() if isinstance(sv, torch.Tensor) else sv + for sk, sv in sd["state"][k].items() + } + self._cpu_offload_pool.offload() + return sd + + def load_state_dict(self, state_dict: dict) -> None: + """Load optimizer state dict, then offload states if needed. + + After ``super().load_state_dict()`` populates GPU tensors, we + re-register them with the offload pool and offload to CPU so the + optimizer is in the same post-step state (GPU storage freed). + """ + # If states were offloaded, reload first so storage sizes are + # correct for super().load_state_dict() to overwrite. + if self.cpu_offload and self._offload_initialized: + self._cpu_offload_pool.reload() + torch.cuda.current_stream().synchronize() + + super().load_state_dict(state_dict) + + if self.cpu_offload: + # Re-create the offload pool since state tensors may be new + # objects after load_state_dict. 
+ self._cpu_offload_pool = CPUOffloadPool() + self._offload_initialized = False + self._register_states_for_offload() + self._offload_initialized = True + self._cpu_offload_pool.offload() diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py index f3fed6e6d186242df1e7e6e89b4416e31eb6bc63..2b1a938d06acf1a40985bda013a9061a8d42e407 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/newton_schulz.py @@ -1,3 +1,7 @@ +from itertools import repeat +from math import inf, sqrt + +import numpy as np import torch from .matmul_transpose_triton import matmul_transpose_assign @@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16 DEFAULT_CHUNK_SIZE_RATIO = 4 -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. +def _optimal_quintic(l, u, max_iter=1000): + """ + Use the simplified Remez algorithm to find the optimal odd quintic approximant + to the constant function x -> 1 over the interval [l, u]. + + Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum + approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the + two interior equioscillation nodes q, r until convergence. Returns the + closed-form equioscillating solution when l ≈ u. + + Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite + (NaN or inf). Raises RuntimeError if convergence is not reached within + max_iter iterations. 
+ """ + assert 0 <= l <= u + if 1 - 5e-6 <= l / u: + return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) + q = (3 * l + u) / 4 + r = (l + 3 * u) / 4 + E = inf + for _ in range(max_iter): + old_E = E + LHS = np.array([ + [l, l**3, l**5, 1], + [q, q**3, q**5, -1], + [r, r**3, r**5, 1], + [u, u**3, u**5, -1], + ]) + a, b, c, E = np.linalg.solve(LHS, np.ones(4)) + if not np.all(np.isfinite([a, b, c, E])): + raise ValueError(f"_optimal_quintic: non-finite solve result " + f"a={a}, b={b}, c={c}, E={E}") + q, r = np.sqrt( + (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / + (10 * c)) + if not np.all(np.isfinite([q, r])): + raise ValueError( + f"_optimal_quintic: non-finite node update q={q}, r={r}") + if abs(old_E - E) <= 1e-15: + break + else: + raise RuntimeError( + f"_optimal_quintic: did not converge after {max_iter} iterations") + return float(a), float(b), float(c) + + +def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0): + """ + Compute the Polar Express coefficient series for `num_iters` quintic iterations. + + Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that + compose to map singular values from [l, 1] toward 1. At each step: + 1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion` + prevents near-zero singular values from stalling by raising the effective + lower bound; if it is active (cushion*u > l), the coefficients are + rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u]. + 2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the + last iteration, providing numerical headroom at the cost of a slightly slower + final convergence step. + 3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1). + + Returns a list of (a, b, c) tuples, one per iteration. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 + """ + u = 1 + assert 0 <= l <= u + safety_factor = 1 + safety_factor_eps + coefficients = [] + for iter in range(num_iters): + a, b, c = _optimal_quintic(max(l, cushion * u), u) + if cushion * u > l: + pl = a * l + b * l**3 + c * l**5 + pu = a * u + b * u**3 + c * u**5 + rescaler = 2 / (pl + pu) + a *= rescaler + b *= rescaler + c *= rescaler + if iter < num_iters - 1: + a /= safety_factor + b /= safety_factor**3 + c /= safety_factor**5 + coefficients.append((a, b, c)) + l = a * l + b * l**3 + c * l**5 + u = 2 - l + return coefficients + + +# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz +# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic +# approximant to x->1 over the current singular-value interval, computed once at +# import time and reused across all optimizer steps. +# +# Contrast with the former hardcoded NS coefficients (5 fixed tuples): +# - Former: empirically tuned to maximize slope at zero; did not converge +# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead +# of the true polar factor UV^T. +# - Polar Express: analytically optimal per step, adapting to the shrinking +# singular-value interval [l, u] as iterations progress; converges all +# singular values to 1, producing the exact polar factor UV^T. 
+_coeffs_list = _optimal_composition(l=1e-3, + num_iters=10, + safety_factor_eps=1e-2, + cushion=0.02) + + +# This code is adapted from: +# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py) +# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress) +# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon) @torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon def _zeropower_via_newtonschulz5(G, steps): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Compute the polar factor of G via the Polar Express method. + + Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c) + are the Polar Express coefficients from `_coeffs_list`. Each step is the + optimal odd quintic approximant to x -> 1 over the current singular-value + interval, minimizing the maximum approximation error (Remez / minimax criterion). + The composition maps singular values from [l, 1] to near 1, producing the + polar factor (orthogonal factor in the polar decomposition G = UP). + + `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2, + cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated. 
+ + Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and + Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932 """ assert len(G.shape) == 2 assert G.dtype == COMM_DTYPE @@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T - # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: + for a, b, c in hs: matmul_transpose_assign(X, buf1) matmul_transpose_assign(buf1, buf2) buf1.mul_(b).add_(buf2, alpha=c) @@ -47,4 +160,77 @@ def _zeropower_via_newtonschulz5(G, steps): if G.size(0) > G.size(1): X = X.T + return X + + +@torch.no_grad() +def _zeropower_via_newtonschulz5_batched(G, steps): + """Batched polar factor computation for 3D (E, out, in) tensors. + + Same algorithm as ``_zeropower_via_newtonschulz5`` but uses + ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel, + processing all E expert matrices in a single batched call. + """ + assert len(G.shape) == 3 + assert G.dtype == COMM_DTYPE + X = G + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + # Per-expert Frobenius norm. 
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + hs = _coeffs_list[:steps] + list( + repeat(_coeffs_list[-1], steps - len(_coeffs_list))) + for a, b, c in hs: + buf1 = torch.bmm(X, X.transpose(-2, -1)) + buf2 = torch.bmm(buf1, buf1.transpose(-2, -1)) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(1) > G.size(2): + X = X.transpose(-2, -1) + + return X + + +_ns_per_shape: dict[tuple[int, ...], callable] = {} +_use_compile = True + + +def set_ns_compile(enabled: bool): + """Toggle torch.compile for Newton-Schulz iteration.""" + global _use_compile + _use_compile = enabled + + +def zeropower_via_newtonschulz5(G, steps=5): + if not _use_compile: + return _zeropower_via_newtonschulz5(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() + + +def zeropower_via_newtonschulz5_batched(G, steps=5): + """Compile-cached batched Newton-Schulz for 3D expert tensors.""" + if not _use_compile: + return _zeropower_via_newtonschulz5_batched(G, steps) + key = G.shape + if key not in _ns_per_shape: + _ns_per_shape[key] = torch.compile( + _zeropower_via_newtonschulz5_batched, + options={ + "triton.cudagraphs": True, + "shape_padding": False + }) + torch.compiler.cudagraph_mark_step_begin() + return _ns_per_shape[key](G, steps).clone() diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py b/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py index 9241f6d4457e4a7eacc4129056eadef5aa6961f6..c0c2d515856182d8d15ad27dd4e4e093b29397d6 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/pipeline.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.distributed.tensor import DTensor from torch.profiler import record_function 
-from .core import _muon_state, adjust_lr_for_muon, update_p -from .newton_schulz import COMM_DTYPE, _zeropower_via_newtonschulz5 +from .core import _muon_state, adjust_lr_for_muon +from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 from .qk_clip import compute_scales logger = logging.getLogger(__name__) @@ -45,26 +45,33 @@ def _launch_gather( else: gathered_grads[id(p)] = None - # Build send buffer - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch grad copies via torch.cat + # (1-2 fused kernels vs N individual narrow().copy_() calls). send_counts = [0] * num_ranks - for p in params: state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = state.rank_numels[rank] - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in - per_dst), "At least one destination rank must receive a sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + send_counts[state.worker_rank] += state.rank_numels[rank] + + total_send = sum(send_counts) + if total_send > 0: + # Group grad slices by destination rank in a single pass. + dst_to_grads = [[] for _ in range(num_ranks)] + for p in params: + state = param_to_state[id(p)] + n = state.rank_numels[rank] + if n > 0: + g = p.grad.to_local() + dst_to_grads[state.worker_rank].append(g.reshape(-1)) + + # Flatten in dst order and cat once. 
+ all_slices = [] + for dst in range(num_ranks): + all_slices.extend(dst_to_grads[dst]) + send_buf = torch.cat(all_slices) + if send_buf.dtype != COMM_DTYPE: + send_buf = send_buf.to(COMM_DTYPE) + else: + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") # Build recv buffer recv_counts = [0] * num_ranks @@ -120,7 +127,8 @@ def _complete_gather( shard_view = gathered_grads[id(p)][indices] n = shard_view.numel() - assert n > 0 + if n == 0: + continue sg = recv_buf.narrow(0, off + inner_off, n) sg = sg.reshape(shard_view.shape) @@ -143,7 +151,7 @@ def _compute_ns( """ computed_us: dict[int, torch.Tensor | None] = {} for p in owned_params: - u = _zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) + u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps) gathered_grads[id(p)] = None # free gathered grad computed_us[id(p)] = u return computed_us @@ -163,46 +171,47 @@ def _launch_scatter( Returns: work: Async operation handle. recv_buf: Flat receive buffer (needed by ``_complete_scatter``). - scattered_us: ``{id(p): empty_local_tensor}`` for all params. + scattered_us: Empty dict, populated by ``_complete_scatter`` with + zero-copy views into ``recv_buf``. recv_counts: Per-source-rank element counts. """ - # Allocate scattered-u buffers + # scattered_us is populated by _complete_scatter with zero-copy views + # into recv_buf, avoiding N empty_like allocations + N copy_ calls. + # Pre-seed entries for params whose local shard is empty (rank_numels == 0) + # so _update_params can iterate all params without KeyError. 
scattered_us: dict[int, torch.Tensor] = {} for p in params: - scattered_us[id(p)] = torch.empty_like(p.to_local(), dtype=COMM_DTYPE) + if param_to_state[id(p)].rank_numels[rank] == 0: + scattered_us[id(p)] = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) - # Build send buffer (from computed_us on owner ranks) - per_dst: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)] + # Build send buffer – batch via torch.cat + # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls). send_counts = [0] * num_ranks - if owned_params: for p in owned_params: state = param_to_state[id(p)] - - assert computed_us[id(p)] is not None - u_full = computed_us[id(p)].to(COMM_DTYPE).contiguous() - - total_sent = 0 for dst_rank in range(num_ranks): - indices = state.rank_indices[dst_rank] - su = u_full[indices].flatten() - - n = su.numel() - assert n > 0 + send_counts[dst_rank] += state.rank_numels[dst_rank] - per_dst[dst_rank].append(su) - send_counts[dst_rank] += n - total_sent += n - - assert total_sent == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - per_dst_flat = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst_flat, dim=0) + total_send = sum(send_counts) + if total_send > 0: + # Cache u_full conversions to avoid redundant .to() per dst_rank. + u_fulls = {} + for p in owned_params: + u_fulls[id(p)] = computed_us[id(p)].to(COMM_DTYPE).contiguous() + + # Collect slices in dst order (matches all-to-all send layout). 
+ all_slices = [] + for dst_rank in range(num_ranks): + for p in owned_params: + state = param_to_state[id(p)] + su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten() + if su.numel() > 0: + all_slices.append(su) + + send_buf = torch.cat(all_slices) if all_slices else torch.empty( + 0, dtype=COMM_DTYPE, device="cuda") else: send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") @@ -218,7 +227,6 @@ def _launch_scatter( recv_counts[src] = total recv_total = sum(recv_counts) - assert recv_total > 0 recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") # Launch async all-to-all @@ -242,7 +250,13 @@ def _complete_scatter( rank: int, scattered_us: dict[int, torch.Tensor], ) -> None: - """Copy recv buffer into scattered_us (in-place).""" + """Populate scattered_us with zero-copy views into recv_buf. + + Instead of pre-allocating tensors and copying, we assign views directly + from ``recv_buf``. This eliminates N ``empty_like`` + N ``copy_`` calls. + The underlying storage of ``recv_buf`` is kept alive through the views + until ``scattered_us`` is cleared after ``_update_params``. + """ off = 0 for src in range(len(recv_counts)): block = recv_counts[src] @@ -255,11 +269,11 @@ def _complete_scatter( if state.worker_rank != src: continue n = state.rank_numels[rank] - assert n > 0 + if n == 0: + continue - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - scattered_us[id(p)].copy_(flat_local) + scattered_us[id(p)] = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) inner_off += n @@ -275,23 +289,40 @@ def _update_params( lr: float, weight_decay: float, ) -> None: - """Apply weight decay, Muon update, and optional QK clipping.""" - for p in params: - state = param_to_state[id(p)] - u_dtensor = DTensor.from_local( - scattered_us[id(p)], - placements=p.placements, - device_mesh=p.device_mesh, - ) + """Apply weight decay, Muon update, and optional QK clipping. 
+ Uses batched ``_foreach_mul_`` for weight decay and batched + ``_foreach_add_`` for the Muon update, grouping parameters by + adjusted_lr to minimize kernel launches while preserving float32 + precision for the alpha scaling. + """ + if not params: + return + + # Batched weight decay: p *= (1 - lr * wd) — single fused kernel. + p_locals = [p._local_tensor for p in params] + torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay) + + # Group params by adjusted_lr so _foreach_add_ can use a single + # alpha per group (preserves float32 precision for alpha scaling). + lr_groups: dict[float, tuple[list, list]] = {} + for p in params: adjusted_lr = adjust_lr_for_muon(lr, p.shape) - update_p(p, u_dtensor, lr, adjusted_lr, weight_decay) + if adjusted_lr not in lr_groups: + lr_groups[adjusted_lr] = ([], []) + lr_groups[adjusted_lr][0].append(p._local_tensor) + lr_groups[adjusted_lr][1].append(scattered_us[id(p)]) - # QK clipping – applied directly on the local tensor to - # avoid DTensor sharding-propagation issues with _StridedShard. - scales_full = compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None + for adjusted_lr, (p_group, u_group) in lr_groups.items(): + torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr) + + # QK clipping – applied directly on the local tensor to + # avoid DTensor sharding-propagation issues with _StridedShard. + for p in params: + state = param_to_state[id(p)] + if state.qk_clip_state is None: + continue + scales_full = compute_scales(p, state.qk_clip_state) if scales_full is not None: ratio = p.shape[0] // scales_full.shape[0] idx0 = state.rank_indices[rank][0] @@ -304,6 +335,45 @@ def _update_params( p._local_tensor.mul_(row_scales.view(-1, 1)) +# ====================================================================== +# Pre-launch helper for overlapping first chunk's gather with other work. 
+# ====================================================================== + + +@torch.no_grad() +def prelaunch_first_gather( + params: list[DTensor], + param_to_state: dict[int, _muon_state], + rank: int, + none_grad: bool, +) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]: + """Launch the first chunk's A2A gather early for overlap with other compute. + + Call this *before* expensive GPU work (e.g. batched expert NS) so that + the NCCL all-to-all runs concurrently on the NCCL stream while the + default stream executes compute. + + Returns the same 4-tuple that ``_launch_gather`` produces, which should + be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`. + """ + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + with record_function("muon::prelaunch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None + + return work, recv_buf, gathered_grads, recv_counts + + # ====================================================================== # Main generator – thin orchestrator that wires stages together. # ====================================================================== @@ -318,6 +388,7 @@ def muon_chunk_pipeline( lr: float, weight_decay: float, none_grad: bool, + prelaunch_gather: tuple | None = None, ) -> Generator[None, None, None]: """Process one chunk of parameters through the full Muon pipeline. @@ -334,9 +405,12 @@ def muon_chunk_pipeline( runs concurrently on the NCCL stream — no separate ``comm_stream`` is required. + If ``prelaunch_gather`` is provided, the gather was already launched + by :func:`prelaunch_first_gather` and we skip launching it again. + Yields exactly **2** times: - 1. 
After launching async all-to-all gather. + 1. After launching async all-to-all gather (or immediately if pre-launched). 2. After launching async all-to-all scatter. """ process_group = param_to_state[id(params[0])].process_group @@ -345,15 +419,19 @@ def muon_chunk_pipeline( p for p in params if param_to_state[id(p)].worker_rank == rank ] - # Stages 1-2: launch async gather. - with record_function("muon::launch_gather"): - work, recv_buf, gathered_grads, recv_counts = _launch_gather( - params, owned_params, param_to_state, rank, num_ranks, - process_group) - - if none_grad: - for p in params: - p.grad = None + if prelaunch_gather is not None: + # Gather was pre-launched; none_grad already handled by caller. + work, recv_buf, gathered_grads, recv_counts = prelaunch_gather + else: + # Normal path: launch async gather. + with record_function("muon::launch_gather"): + work, recv_buf, gathered_grads, recv_counts = _launch_gather( + params, owned_params, param_to_state, rank, num_ranks, + process_group) + + if none_grad: + for p in params: + p.grad = None yield # --- YIELD 1: other chunks can launch their gather --- diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py index 0d8f7199afa361bfb011ebdd4ed84b03709aaee7..9bd14b01bb8fa00e246ee34d2483616b4f3230ed 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/qk_clip.py @@ -5,6 +5,8 @@ from dataclasses import dataclass import torch from torch.distributed.tensor import DTensor +from .core import normalize_fqn + logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]: 'model.7.attn.k_proj.weight' -> ('k_proj', 7) 'model.4.attn.v_proj.weight' -> (None, -1) """ - parts = name.split('.') + parts = normalize_fqn(name).split('.') if len(parts) < 3: return None, -1 @@ -100,23 +102,27 @@ def compute_scales(p, qk_clip_state): threshold = 
qk_clip_state.threshold logit = qk_clip_state.logit - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - + # Check if any head exceeds threshold before allocating. + head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale + if head_idx not in head_scales or new_scale < head_scales[head_idx]: + head_scales[head_idx] = new_scale logger.info( f"[{kind}] Head {head_idx} exceeded threshold " f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" ) - scaling += 1 - return scales_full if scaling > 0 else None + if not head_scales: + return None + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + for head_idx, scale in head_scales.items(): + scales_full[head_idx] = scale + return scales_full def qk_clip(p, scales, head_dim): diff --git a/test/test_cpu_memory_peak.py b/test/test_cpu_memory_peak.py new file mode 100644 index 0000000000000000000000000000000000000000..ba9864d229f6ca9536eec53ec6d5c3955c1ea157 --- /dev/null +++ b/test/test_cpu_memory_peak.py @@ -0,0 +1,541 @@ +"""CPU memory peak vs offloaded tensor size verification. + +Compares CPU memory usage with cpu_offload=True vs False to isolate +the actual CPU cost of offloading, separating it from CUDA runtime, +NCCL, and DTensor overhead. 
+ +Run with: + torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_memory_peak.py +""" + +import gc +import logging +import os + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor, Shard, distribute_tensor + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + + +def _setup(): + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + torch.cuda.set_device(rank % torch.cuda.device_count()) + return rank, dist.get_world_size() + + +def _make_mesh(world_size): + return dist.init_device_mesh("cuda", (world_size, ), + mesh_dim_names=("dp", )) + + +def get_cpu_rss_bytes(): + """Get current process RSS in bytes from /proc/self/statm.""" + with open("/proc/self/statm") as f: + pages = int(f.read().split()[1]) + return pages * os.sysconf("SC_PAGE_SIZE") + + +def get_pinned_pool_bytes(pool): + """Get total pinned CPU buffer size from CPUOffloadPool.""" + total = 0 + for grp in pool._groups.values(): + cpu_flat = grp["cpu_flat"] + total += cpu_flat.numel() * cpu_flat.element_size() + return total + + +def _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload): + """Run Muon optimizer steps and return final CPU RSS.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + gc.collect() + torch.cuda.empty_cache() + + params, names = [], [] + for i in range(num_params): + full = torch.randn(dim0, dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + params.append(p) + names.append(f"layer.{i}.weight") + + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=cpu_offload) + + for 
step_idx in range(num_steps): + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + + gc.collect() + cpu_rss = get_cpu_rss_bytes() + + pinned_bytes = 0 + if cpu_offload and hasattr(optim, '_cpu_offload_pool'): + pool = optim._cpu_offload_pool + pinned_bytes = get_pinned_pool_bytes(pool) + + # Cleanup. + del optim, params, param_groups + gc.collect() + torch.cuda.empty_cache() + + set_ns_compile(True) + return cpu_rss, pinned_bytes + + +def test_offload_cpu_cost_isolation(rank, world_size): + """A/B test: measure CPU cost of offload by comparing ON vs OFF.""" + mesh = _make_mesh(world_size) + + dim0, dim1 = 2048, 4096 + num_params = 8 + num_steps = 3 + + if rank == 0: + logger.info("=" * 70) + logger.info("A/B TEST: CPU MEMORY COST OF OFFLOAD (ON vs OFF)") + logger.info("=" * 70) + logger.info("Config: %d params of shape (%d, %d), %d ranks, %d steps", + num_params, dim0, dim1, world_size, num_steps) + logger.info("Local param shape per rank: (%d, %d)", dim0 // world_size, + dim1) + logger.info("-" * 70) + + # Run WITHOUT offload first (baseline). + gc.collect() + torch.cuda.empty_cache() + cpu_before_no_offload = get_cpu_rss_bytes() + cpu_after_no_offload, _ = _run_muon_steps(mesh, + dim0, + dim1, + num_params, + num_steps, + cpu_offload=False) + cpu_growth_no_offload = cpu_after_no_offload - cpu_before_no_offload + + # Run WITH offload. + gc.collect() + torch.cuda.empty_cache() + cpu_before_offload = get_cpu_rss_bytes() + cpu_after_offload, pinned_bytes = _run_muon_steps(mesh, + dim0, + dim1, + num_params, + num_steps, + cpu_offload=True) + cpu_growth_offload = cpu_after_offload - cpu_before_offload + + # Delta = additional CPU cost from offloading. 
+ offload_delta = cpu_growth_offload - cpu_growth_no_offload + + if rank == 0: + logger.info("CPU growth WITHOUT offload: %.2f MB", + cpu_growth_no_offload / 1024**2) + logger.info("CPU growth WITH offload: %.2f MB", + cpu_growth_offload / 1024**2) + logger.info("-" * 70) + logger.info("Pinned buffer size (expected): %.2f MB", + pinned_bytes / 1024**2) + logger.info("Offload delta (WITH - WITHOUT): %.2f MB", + offload_delta / 1024**2) + + if pinned_bytes > 0: + ratio = offload_delta / pinned_bytes + logger.info("Ratio (delta / pinned buffer): %.2fx", ratio) + + if ratio > 1.5: + logger.warning( + "Offload adds %.2f MB CPU memory but pinned buffer is " + "only %.2f MB (%.1f%% overhead beyond expected)", + offload_delta / 1024**2, + pinned_bytes / 1024**2, + (offload_delta - pinned_bytes) / pinned_bytes * 100, + ) + else: + logger.info("Offload CPU cost is within expected range.") + + # Only assert on rank 0 to avoid multi-rank assertion mismatches. + if rank == 0 and pinned_bytes > 0: + ratio = offload_delta / pinned_bytes + assert ratio < 3.0, ( + f"Offload CPU cost ({offload_delta / 1024**2:.2f} MB) is " + f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB). " + f"Expected < 3.0x.") + + if rank == 0: + logger.info("PASSED: test_offload_cpu_cost_isolation") + + +def test_cpu_memory_peak_detailed(rank, world_size): + """Detailed per-phase CPU memory tracking for offload.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + + mesh = _make_mesh(world_size) + + dim0, dim1 = 2048, 4096 + num_params = 8 + + gc.collect() + torch.cuda.empty_cache() + + if rank == 0: + logger.info("=" * 70) + logger.info("DETAILED PER-PHASE CPU MEMORY TRACKING") + logger.info("=" * 70) + + cpu_0 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 0] Baseline RSS: %.2f MB", cpu_0 / 1024**2) + + # Create params. 
+ params, names = [], [] + for i in range(num_params): + full = torch.randn(dim0, dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + params.append(p) + names.append(f"layer.{i}.weight") + + gc.collect() + cpu_1 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 1] After param creation: %.2f MB (+%.2f MB)", + cpu_1 / 1024**2, (cpu_1 - cpu_0) / 1024**2) + + # Create optimizer. + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=True) + + gc.collect() + cpu_2 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 2] After optimizer creation: %.2f MB (+%.2f MB)", + cpu_2 / 1024**2, (cpu_2 - cpu_1) / 1024**2) + + # Set grads. + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + + gc.collect() + cpu_3 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 3] After grad creation: %.2f MB (+%.2f MB)", + cpu_3 / 1024**2, (cpu_3 - cpu_2) / 1024**2) + + # Step 1 (creates states + first offload). + optim.step() + torch.cuda.synchronize() + gc.collect() + cpu_4 = get_cpu_rss_bytes() + + pool = optim._cpu_offload_pool + pinned_bytes = get_pinned_pool_bytes(pool) + + if rank == 0: + logger.info( + "[Phase 4] After step 1 (init+offload): %.2f MB (+%.2f MB)", + cpu_4 / 1024**2, (cpu_4 - cpu_3) / 1024**2) + logger.info(" Pinned buffer size: %.2f MB", pinned_bytes / 1024**2) + logger.info(" Step 1 growth vs pinned: %.2f MB extra", + (cpu_4 - cpu_3 - pinned_bytes) / 1024**2) + + # Step 2 (reload + compute + offload). 
+ for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + gc.collect() + cpu_5 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 5] After step 2: %.2f MB (+%.2f MB)", + cpu_5 / 1024**2, (cpu_5 - cpu_4) / 1024**2) + + # Step 3. + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + gc.collect() + cpu_6 = get_cpu_rss_bytes() + if rank == 0: + logger.info("[Phase 6] After step 3: %.2f MB (+%.2f MB)", + cpu_6 / 1024**2, (cpu_6 - cpu_5) / 1024**2) + + # Summary. + total_growth = cpu_6 - cpu_0 + if rank == 0: + logger.info("-" * 70) + logger.info("SUMMARY:") + logger.info(" Total CPU growth: %.2f MB", total_growth / 1024**2) + logger.info(" Pinned buffer: %.2f MB", pinned_bytes / 1024**2) + logger.info(" Overhead: %.2f MB", + (total_growth - pinned_bytes) / 1024**2) + if pinned_bytes > 0: + logger.info(" Ratio: %.2fx", + total_growth / pinned_bytes) + logger.info("") + logger.info(" NOTE: Overhead includes CUDA runtime, NCCL buffers,") + logger.info(" DTensor metadata, and optimizer internals — NOT just") + logger.info(" offload cost. 
Use A/B test for isolated measurement.") + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_cpu_memory_peak_detailed") + + +def test_offload_cpu_cost_mixed(rank, world_size): + """A/B test for mixed Muon + AdamW offload CPU cost.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + mesh = _make_mesh(world_size) + + muon_dim0, muon_dim1 = 2048, 4096 + num_muon = 8 + adamw_dim = 4096 + num_adamw = 8 + num_steps = 3 + + def run_mixed(cpu_offload): + set_ns_compile(False) + torch.manual_seed(42) + gc.collect() + torch.cuda.empty_cache() + + muon_params, muon_names = [], [] + for i in range(num_muon): + full = torch.randn(muon_dim0, muon_dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + muon_params.append(p) + muon_names.append(f"layer.{i}.weight") + + adamw_params = [] + for i in range(num_adamw): + full = torch.randn(adamw_dim, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + adamw_params.append(p) + + param_groups = [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + "adamw_betas": (0.9, 0.95), + "adamw_eps": 1e-8, + }, + { + "params": adamw_params, + "use_muon": False, + "lr": 1e-3, + "weight_decay": 0.01, + "adamw_betas": (0.9, 0.95), + "adamw_eps": 1e-8, + }, + ] + + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=cpu_offload) + + for step_idx in range(num_steps): + for p in muon_params: + p.grad = distribute_tensor( + torch.randn(muon_dim0, muon_dim1, device="cuda"), mesh, + [Shard(0)]) + for p in adamw_params: + p.grad = distribute_tensor( + torch.randn(adamw_dim, device="cuda"), mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + + gc.collect() + cpu_rss = get_cpu_rss_bytes() + + pinned_bytes = 0 + if cpu_offload and 
hasattr(optim, '_cpu_offload_pool'): + pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool) + + del optim, muon_params, adamw_params, param_groups + gc.collect() + torch.cuda.empty_cache() + set_ns_compile(True) + return cpu_rss, pinned_bytes + + if rank == 0: + logger.info("=" * 70) + logger.info("A/B TEST: CPU COST OF MIXED OFFLOAD (Muon + AdamW)") + logger.info("=" * 70) + + gc.collect() + torch.cuda.empty_cache() + cpu_before_no = get_cpu_rss_bytes() + cpu_after_no, _ = run_mixed(False) + growth_no = cpu_after_no - cpu_before_no + + gc.collect() + torch.cuda.empty_cache() + cpu_before_yes = get_cpu_rss_bytes() + cpu_after_yes, pinned_bytes = run_mixed(True) + growth_yes = cpu_after_yes - cpu_before_yes + + delta = growth_yes - growth_no + + if rank == 0: + logger.info("CPU growth WITHOUT offload: %.2f MB", growth_no / 1024**2) + logger.info("CPU growth WITH offload: %.2f MB", + growth_yes / 1024**2) + logger.info("Pinned buffer size: %.2f MB", + pinned_bytes / 1024**2) + logger.info("Offload delta: %.2f MB", delta / 1024**2) + if pinned_bytes > 0: + logger.info("Ratio (delta / pinned): %.2fx", + delta / pinned_bytes) + + if rank == 0 and pinned_bytes > 0: + ratio = delta / pinned_bytes + assert ratio < 3.0, ( + f"Mixed offload CPU cost ({delta / 1024**2:.2f} MB) is " + f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB)." + ) + + if rank == 0: + logger.info("PASSED: test_offload_cpu_cost_mixed") + + +def test_pinned_memory_rss_overhead(rank, world_size): + """Isolate: does cudaHostAlloc itself cause 2x RSS overhead?""" + sizes_mb = [8, 16, 32, 64, 128] + + if rank == 0: + logger.info("=" * 70) + logger.info("ISOLATED TEST: PINNED MEMORY RSS OVERHEAD") + logger.info("=" * 70) + + for size_mb in sizes_mb: + numel = size_mb * 1024 * 1024 // 4 # float32 + + # Test 1: pin_memory=True (direct allocation). 
+ gc.collect() + torch.cuda.empty_cache() + rss_before = get_cpu_rss_bytes() + t1 = torch.empty(numel, + dtype=torch.float32, + device="cpu", + pin_memory=True) + rss_after = get_cpu_rss_bytes() + rss_growth_direct = rss_after - rss_before + del t1 + gc.collect() + + # Test 2: .pin_memory() (copy-based). + gc.collect() + torch.cuda.empty_cache() + rss_before2 = get_cpu_rss_bytes() + t2 = torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory() + rss_after2 = get_cpu_rss_bytes() + rss_growth_copy = rss_after2 - rss_before2 + del t2 + gc.collect() + + # Test 3: regular (non-pinned) CPU allocation. + gc.collect() + torch.cuda.empty_cache() + rss_before3 = get_cpu_rss_bytes() + t3 = torch.empty(numel, dtype=torch.float32, device="cpu") + # Touch all pages to ensure RSS reflects actual allocation. + t3.fill_(1.0) + rss_after3 = get_cpu_rss_bytes() + rss_growth_regular = rss_after3 - rss_before3 + del t3 + gc.collect() + + if rank == 0: + logger.info( + "%3d MB: pin_memory=True → RSS +%.1f MB (%.2fx) | " + ".pin_memory() → RSS +%.1f MB (%.2fx) | " + "regular → RSS +%.1f MB (%.2fx)", + size_mb, + rss_growth_direct / 1024**2, + rss_growth_direct / (size_mb * 1024**2) if size_mb > 0 else 0, + rss_growth_copy / 1024**2, + rss_growth_copy / (size_mb * 1024**2) if size_mb > 0 else 0, + rss_growth_regular / 1024**2, + rss_growth_regular / (size_mb * 1024**2) if size_mb > 0 else 0, + ) + + if rank == 0: + logger.info("PASSED: test_pinned_memory_rss_overhead") + + +def main(): + rank, world_size = _setup() + + try: + test_pinned_memory_rss_overhead(rank, world_size) + test_cpu_memory_peak_detailed(rank, world_size) + test_offload_cpu_cost_isolation(rank, world_size) + test_offload_cpu_cost_mixed(rank, world_size) + + if rank == 0: + logger.info("=" * 50) + logger.info("ALL CPU MEMORY PEAK TESTS PASSED") + logger.info("=" * 50) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/test/test_cpu_offload.py 
b/test/test_cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..502ce56dda64013ae6f7c334df05a4cc4471afdc --- /dev/null +++ b/test/test_cpu_offload.py @@ -0,0 +1,859 @@ +"""CPU offloading tests for optimizer states. + +Run with: + torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_offload.py + +Tests: + 1. Correctness: cpu_offload=True produces identical results to False + 2. Memory: GPU optimizer state storage is freed after offload + 3. AdamW: moment1/moment2 offloading works correctly +""" + +import logging +import sys + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor, Shard, distribute_tensor + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + + +def _setup(): + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + torch.cuda.set_device(rank % torch.cuda.device_count()) + return rank, dist.get_world_size() + + +def _make_mesh(world_size): + return dist.init_device_mesh("cuda", (world_size, ), + mesh_dim_names=("dp", )) + + +def test_correctness(rank, world_size): + """Verify that cpu_offload=True produces identical parameters as False.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + + mesh = _make_mesh(world_size) + + dim0, dim1 = 64, 128 + num_params = 4 + num_steps = 3 + + # Pre-generate all data on all ranks (same seed → same values). 
+ full_params = [ + torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) + ] + full_grads = [[ + torch.randn(dim0, dim1, device="cuda") for _ in range(num_params) + ] for _ in range(num_steps)] + + def make_optimizer(cpu_offload): + params, names = [], [] + for i, fp in enumerate(full_params): + dt = distribute_tensor(fp.clone(), mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + params.append(p) + names.append(f"layer.{i}.weight") + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=cpu_offload) + return optim, params + + optim_ref, params_ref = make_optimizer(False) + optim_off, params_off = make_optimizer(True) + + for step_idx in range(num_steps): + for i in range(num_params): + g = full_grads[step_idx][i] + params_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + params_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + + optim_ref.step() + optim_off.step() + + for i in range(num_params): + ref_full = params_ref[i].data.full_tensor() + off_full = params_off[i].data.full_tensor() + torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) + + if rank == 0: + logger.info("Step %d: correctness OK", step_idx) + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_correctness") + + +def test_memory(rank, world_size): + """Verify that GPU storage is freed after offload.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + + mesh = _make_mesh(world_size) + + dim0, dim1 = 512, 1024 + num_params = 8 + + params, names = [], [] + for i in range(num_params): + full = torch.randn(dim0, dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + p.grad = 
distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + params.append(p) + names.append(f"layer.{i}.weight") + + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=True) + + optim.step() + torch.cuda.synchronize() + + # After step + offload, all momentum buffer GPU storage should be freed. + for p in params: + state = optim.state[p] + if "momentum_buffer" not in state: + continue + buf = state["momentum_buffer"] + local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf + assert local_buf.untyped_storage().size() == 0, ( + f"Expected freed GPU storage after offload, got " + f"{local_buf.untyped_storage().size()} bytes") + + # Verify CPU pool has pinned buffers. + pool = optim._cpu_offload_pool + assert len(pool._managed) > 0, "No tensors tracked by CPU offload pool" + for grp in pool._groups.values(): + assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned memory" + + # Run another step to verify reload + compute + offload cycle works. + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + + # Storage should be freed again after second step. 
+ for p in params: + state = optim.state[p] + if "momentum_buffer" not in state: + continue + buf = state["momentum_buffer"] + local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf + assert local_buf.untyped_storage().size() == 0 + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_memory") + + +def test_adamw_offload(rank, world_size): + """Verify AdamW moment1/moment2 are offloaded correctly.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + + mesh = _make_mesh(world_size) + + num_steps = 3 + + # Create both Muon (2D) and AdamW (1D) params. + muon_params, muon_names = [], [] + adamw_params, adamw_names = [], [] + + for i in range(4): + full = torch.randn(64, 128, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + muon_params.append(p) + muon_names.append(f"layer.{i}.weight") + + for i in range(3): + full = torch.randn(128, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + adamw_params.append(p) + adamw_names.append(f"layer.{i}.bias") + + # Pre-generate grads. 
+ muon_grads = [[torch.randn(64, 128, device="cuda") for _ in range(4)] + for _ in range(num_steps)] + adamw_grads = [[torch.randn(128, device="cuda") for _ in range(3)] + for _ in range(num_steps)] + + def make_optimizer(cpu_offload): + mp = [ + torch.nn.Parameter( + distribute_tensor(p.data.full_tensor().clone(), mesh, + [Shard(0)])) for p in muon_params + ] + ap = [ + torch.nn.Parameter( + distribute_tensor(p.data.full_tensor().clone(), mesh, + [Shard(0)])) for p in adamw_params + ] + param_groups = [ + { + "params": mp, + "names": list(muon_names), + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + "adamw_betas": (0.9, 0.95), + "adamw_eps": 1e-8, + }, + { + "params": ap, + "use_muon": False, + "lr": 1e-3, + "weight_decay": 0.01, + "adamw_betas": (0.9, 0.95), + "adamw_eps": 1e-8, + }, + ] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=cpu_offload) + return optim, mp, ap + + optim_ref, mp_ref, ap_ref = make_optimizer(False) + optim_off, mp_off, ap_off = make_optimizer(True) + + for step_idx in range(num_steps): + for i in range(4): + g = muon_grads[step_idx][i] + mp_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + mp_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + for i in range(3): + g = adamw_grads[step_idx][i] + ap_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + ap_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) + + optim_ref.step() + optim_off.step() + + # Compare Muon params. + for i in range(4): + ref_full = mp_ref[i].data.full_tensor() + off_full = mp_off[i].data.full_tensor() + torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) + + # Compare AdamW params. 
+ for i in range(3): + ref_full = ap_ref[i].data.full_tensor() + off_full = ap_off[i].data.full_tensor() + torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) + + if rank == 0: + logger.info("Step %d: AdamW offload correctness OK", step_idx) + + # Verify AdamW states are offloaded. + for p in ap_off: + state = optim_off.state[p] + for key in ("moment1", "moment2"): + if key not in state: + continue + t = state[key] + local_t = t._local_tensor if isinstance(t, DTensor) else t + assert local_t.untyped_storage().size() == 0, ( + f"AdamW {key} storage not freed after offload") + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_adamw_offload") + + +def test_memory_savings(rank, world_size): + """Measure actual GPU memory difference with and without offload.""" + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + + mesh = _make_mesh(world_size) + dim0, dim1 = 1024, 2048 + num_params = 8 + + def run_step(cpu_offload): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.manual_seed(42) + + params, names = [], [] + for i in range(num_params): + full = torch.randn(dim0, dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + params.append(p) + names.append(f"layer.{i}.weight") + + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=cpu_offload) + optim.step() + torch.cuda.synchronize() + + mem = torch.cuda.memory_allocated() + # Clean up to avoid interference. 
+ del optim, params, param_groups + torch.cuda.empty_cache() + return mem + + mem_no_offload = run_step(False) + mem_with_offload = run_step(True) + + if rank == 0: + logger.info("Memory without offload: %.2f MB", + mem_no_offload / 1024**2) + logger.info("Memory with offload: %.2f MB", + mem_with_offload / 1024**2) + saved = mem_no_offload - mem_with_offload + logger.info("Memory saved: %.2f MB", saved / 1024**2) + + assert mem_with_offload < mem_no_offload, ( + f"Expected memory reduction with CPU offload. " + f"Without: {mem_no_offload / 1024**2:.2f} MB, " + f"With: {mem_with_offload / 1024**2:.2f} MB") + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_memory_savings") + + +def test_leak(rank, world_size): + """Run many iterations and verify no CPU/GPU memory leak.""" + import os + + from optimizer.muon import Muon + from optimizer.newton_schulz import set_ns_compile + + set_ns_compile(False) + torch.manual_seed(42) + + mesh = _make_mesh(world_size) + + dim0, dim1 = 512, 1024 + num_params = 8 + num_steps = 50 + + params, names = [], [] + for i in range(num_params): + full = torch.randn(dim0, dim1, device="cuda") + dt = distribute_tensor(full, mesh, [Shard(0)]) + p = torch.nn.Parameter(dt) + params.append(p) + names.append(f"layer.{i}.weight") + + param_groups = [{ + "params": params, + "names": names, + "use_muon": True, + "lr": 0.02, + "weight_decay": 0.01, + "momentum": 0.95, + "nesterov": True, + "ns_steps": 5, + "none_grad": False, + }] + optim = Muon(params=param_groups, + chunk_size=2, + warmup_step=1, + cpu_offload=True) + + def get_cpu_rss_mb(): + """Get current process RSS in MB from /proc/self/statm.""" + with open("/proc/self/statm") as f: + pages = int(f.read().split()[1]) + return pages * os.sysconf("SC_PAGE_SIZE") / (1024**2) + + gpu_after_warmup = None + cpu_after_warmup = None + + for step_idx in range(num_steps): + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + + 
def test_state_dict_save_load(rank, world_size):
    """Round-trip optimizer state through DCP and check training resumes identically.

    One offloading optimizer is stepped for the first half of the schedule
    and its ``state_dict()`` is captured.  Two fresh optimizers are then
    built on clones of the resulting parameters: the reference loads a
    ``copy.deepcopy`` of that state dict, the resumed one loads the state
    tensors after they were written and re-read with
    ``torch.distributed.checkpoint``.  Both run the second half on the same
    gradients, and their parameters are compared with zero tolerance.

    Args:
        rank: Rank of this process; only rank 0 logs and owns the temp dir.
        world_size: Number of processes, forwarded to ``_make_mesh``.
    """
    import copy
    import shutil
    import tempfile

    import torch.distributed.checkpoint as dcp
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)

    mesh = _make_mesh(world_size)

    rows, cols = 64, 128
    n_muon = 4
    n_adamw = 3
    half = 3
    total_steps = half * 2

    def shard(full):
        """Distribute a full tensor over the mesh, sharded on dim 0."""
        return distribute_tensor(full, mesh, [Shard(0)])

    # All random data is drawn up front, in a fixed order, so every
    # optimizer below sees exactly the same initial weights and gradients.
    muon_init = [
        torch.randn(rows, cols, device="cuda") for _ in range(n_muon)
    ]
    adamw_init = [torch.randn(cols, device="cuda") for _ in range(n_adamw)]
    muon_grads = [[
        torch.randn(rows, cols, device="cuda") for _ in range(n_muon)
    ] for _ in range(total_steps)]
    adamw_grads = [[
        torch.randn(cols, device="cuda") for _ in range(n_adamw)
    ] for _ in range(total_steps)]

    def build(cpu_offload):
        """Create a Muon optimizer over fresh copies of the initial weights."""
        muon_params = [
            torch.nn.Parameter(shard(w.clone())) for w in muon_init
        ]
        adamw_params = [
            torch.nn.Parameter(shard(w.clone())) for w in adamw_init
        ]
        muon_group = {
            "params": muon_params,
            "names": [f"layer.{i}.weight" for i in range(n_muon)],
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        adamw_group = {
            "params": adamw_params,
            "use_muon": False,
            "lr": 1e-3,
            "weight_decay": 0.01,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        opt = Muon(params=[muon_group, adamw_group],
                   chunk_size=2,
                   warmup_step=1,
                   cpu_offload=cpu_offload)
        return opt, muon_params, adamw_params

    # --- First half: a single offloading optimizer produces the state that
    # both fresh optimizers will later load.
    optim_off, mp_off, ap_off = build(True)

    for step in range(half):
        for param, grad in zip(mp_off, muon_grads[step]):
            param.grad = shard(grad.clone())
        for param, grad in zip(ap_off, adamw_grads[step]):
            param.grad = shard(grad.clone())
        optim_off.step()

    sd_off = optim_off.state_dict()

    # The returned state tensors must own real storage.
    for per_param in sd_off["state"].values():
        for key, val in per_param.items():
            if not isinstance(val, torch.Tensor):
                continue
            if not val.is_floating_point():
                continue
            assert val.untyped_storage().size() > 0, (
                f"state_dict() returned empty storage for key '{key}' — "
                f"offload reload is broken")

    if rank == 0:
        logger.info("state_dict() contains valid (non-empty) tensors")

    # DCP wants a flat mapping with string keys.
    flat_state = {
        f"state.{param_idx}.{key}": val
        for param_idx, per_param in sd_off["state"].items()
        for key, val in per_param.items()
        if isinstance(val, torch.Tensor)
    }

    # Rank 0 creates the checkpoint directory and broadcasts its path so
    # every rank saves to / loads from the same location.
    dir_box = [
        tempfile.mkdtemp(prefix="cpu_offload_test_") if rank == 0 else ""
    ]
    dist.broadcast_object_list(dir_box, src=0)
    ckpt_dir = dir_box[0]

    def copy_weights_from_first_half(muon_params, adamw_params):
        """Overwrite fresh parameters with the first-half optimizer's weights."""
        for dst, src in zip(muon_params, mp_off):
            dst.data = src.data.clone()
        for dst, src in zip(adamw_params, ap_off):
            dst.data = src.data.clone()

    try:
        dcp.save(flat_state, checkpoint_id=ckpt_dir)
        dist.barrier()

        if rank == 0:
            logger.info("DCP save completed to %s", ckpt_dir)

        # --- Reference: fresh optimizer, state loaded without serialization.
        optim_ref, mp_ref, ap_ref = build(True)
        copy_weights_from_first_half(mp_ref, ap_ref)
        optim_ref.load_state_dict(copy.deepcopy(sd_off))

        # --- Resumed: fresh optimizer, state tensors loaded through DCP.
        optim_resumed, mp_resumed, ap_resumed = build(True)
        copy_weights_from_first_half(mp_resumed, ap_resumed)

        flat_target = {
            name: torch.zeros_like(t)
            for name, t in flat_state.items()
        }
        dcp.load(flat_target, checkpoint_id=ckpt_dir)
        dist.barrier()

        sd_loaded = copy.deepcopy(sd_off)
        for param_idx, per_param in sd_loaded["state"].items():
            for key in list(per_param):
                flat_key = f"state.{param_idx}.{key}"
                if flat_key in flat_target:
                    per_param[key] = flat_target[flat_key]
        optim_resumed.load_state_dict(sd_loaded)

        if rank == 0:
            logger.info("Both optimizers loaded, starting comparison steps")

    finally:
        dist.barrier()
        if rank == 0:
            shutil.rmtree(ckpt_dir, ignore_errors=True)

    # --- Second half: both optimizers consume identical gradients.
    for step in range(half, total_steps):
        for ref_p, res_p, grad in zip(mp_ref, mp_resumed, muon_grads[step]):
            ref_p.grad = shard(grad.clone())
            res_p.grad = shard(grad.clone())
        for ref_p, res_p, grad in zip(ap_ref, ap_resumed, adamw_grads[step]):
            ref_p.grad = shard(grad.clone())
            res_p.grad = shard(grad.clone())
        optim_ref.step()
        optim_resumed.step()

    # Final parameters must match exactly (Muon params first, then AdamW).
    for ref_p, res_p in zip(mp_ref + ap_ref, mp_resumed + ap_resumed):
        ref_full = ref_p.data.full_tensor()
        res_full = res_p.data.full_tensor()
        torch.testing.assert_close(ref_full, res_full, atol=0, rtol=0)

    # After step() the resumed optimizer's momentum buffers must have no
    # GPU storage, i.e. offloading is active again.
    for p in mp_resumed:
        state = optim_resumed.state[p]
        if "momentum_buffer" not in state:
            continue
        buf = state["momentum_buffer"]
        local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf
        assert local_buf.untyped_storage().size() == 0, (
            "Resumed optimizer should have offloaded state after step()")

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_state_dict_save_load")
+ state_bytes = 0 + for p in params: + state = optim.state[p] + if "momentum_buffer" in state: + buf = state["momentum_buffer"] + local = buf._local_tensor if isinstance(buf, DTensor) else buf + # Storage is freed, so use the tracked size. + state_bytes += optim._cpu_offload_pool._storage_nbytes[id(buf)] + + if rank == 0: + logger.info( + "After step (offloaded): GPU alloc=%.2f MB, " + "expected state size=%.2f MB", mem_after_step / 1024**2, + state_bytes / 1024**2) + + # Step 2: state_dict() should temporarily reload states to GPU. + sd = optim.state_dict() + torch.cuda.synchronize() + mem_during_state_dict = torch.cuda.memory_allocated() + + if rank == 0: + logger.info("After state_dict (states on GPU): GPU alloc=%.2f MB", + mem_during_state_dict / 1024**2) + + # States should now be on GPU — memory should be higher. + assert mem_during_state_dict > mem_after_step, ( + f"state_dict() should reload states to GPU. " + f"After step: {mem_after_step / 1024**2:.2f} MB, " + f"After state_dict: {mem_during_state_dict / 1024**2:.2f} MB") + + mem_increase = mem_during_state_dict - mem_after_step + if rank == 0: + logger.info( + "Memory increase from reload: %.2f MB " + "(expected ~%.2f MB)", mem_increase / 1024**2, + state_bytes / 1024**2) + + # Step 3: next step() should re-offload and free GPU memory. + # Delete sd to free the cloned state tensors from state_dict(). + sd_for_load = copy.deepcopy(sd) + del sd + torch.cuda.empty_cache() + + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + + mem_after_next_step = torch.cuda.memory_allocated() + + if rank == 0: + logger.info("After next step (re-offloaded): GPU alloc=%.2f MB", + mem_after_next_step / 1024**2) + + # Allow 4 MB tolerance for CUDA allocator fragmentation. + assert mem_after_next_step <= mem_after_step + 4 * 1024 * 1024, ( + f"Memory should return to offloaded level after step(). 
" + f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " + f"got {mem_after_next_step / 1024**2:.2f} MB") + + # Step 4: load_state_dict() should end with states offloaded. + optim.load_state_dict(sd_for_load) + torch.cuda.synchronize() + + mem_after_load = torch.cuda.memory_allocated() + + if rank == 0: + logger.info("After load_state_dict (offloaded): GPU alloc=%.2f MB", + mem_after_load / 1024**2) + + # After load_state_dict, states should be offloaded again. + # Allow some tolerance for PyTorch allocator fragmentation. + assert mem_after_load <= mem_after_step + 1024 * 1024, ( + f"load_state_dict should offload states. " + f"Expected ~{mem_after_step / 1024**2:.2f} MB, " + f"got {mem_after_load / 1024**2:.2f} MB") + + # Verify CPU pinned memory is allocated for the new pool. + pool = optim._cpu_offload_pool + assert pool._initialized, "Offload pool should be initialized after load" + for grp in pool._groups.values(): + assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned" + + # Step 5: verify the loaded optimizer can still step correctly. + for p in params: + p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"), + mesh, [Shard(0)]) + optim.step() + torch.cuda.synchronize() + + mem_final = torch.cuda.memory_allocated() + assert mem_final <= mem_after_step + 4 * 1024 * 1024, ( + f"Final memory should be at offloaded level. 
" + f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " + f"got {mem_final / 1024**2:.2f} MB") + + set_ns_compile(True) + if rank == 0: + logger.info("PASSED: test_checkpoint_memory") + + +def main(): + rank, world_size = _setup() + + try: + test_correctness(rank, world_size) + test_memory(rank, world_size) + test_adamw_offload(rank, world_size) + test_memory_savings(rank, world_size) + test_leak(rank, world_size) + test_state_dict_save_load(rank, world_size) + test_checkpoint_memory(rank, world_size) + + if rank == 0: + logger.info("=" * 50) + logger.info("ALL CPU OFFLOAD TESTS PASSED") + logger.info("=" * 50) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/torch-ext/optimizer/cpu_offload.py b/torch-ext/optimizer/cpu_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbb86a48bdebca1b59027339f4dfa532a990bc7 --- /dev/null +++ b/torch-ext/optimizer/cpu_offload.py @@ -0,0 +1,190 @@ +"""CPU offloading for optimizer states. + +Manages a pinned CPU memory pool and async CUDA streams to offload +optimizer state tensors (momentum buffers, Adam moments) to CPU between +optimizer steps, freeing GPU memory. + +All tracked tensors are packed into a single flat pinned CPU buffer +(per dtype). D2H and H2D copies are performed per-tensor directly +between individual GPU tensors and their slice of the CPU flat buffer +— no GPU staging buffer is allocated, so there is **no temporary GPU +memory spike** during offload or reload. + +Individual tensor storages are freed after offload via +``untyped_storage().resize_(0)``, preserving tensor identity so +downstream caches remain valid. +""" + +import logging +from collections import defaultdict + +import torch +from torch.distributed.tensor import DTensor + +logger = logging.getLogger(__name__) + + +class CPUOffloadPool: + """Pinned CPU memory pool for async optimizer state offloading. + + Tracked tensors are grouped by dtype. 
Each group gets a single flat + pinned CPU buffer. D2H / H2D copies are per-tensor (into slices of + the flat buffer) to avoid allocating a GPU staging buffer. + """ + + def __init__(self): + self._managed: list[torch.Tensor] = [] + self._storage_nbytes: dict[int, int] = {} # id(t) → bytes + + # Per-dtype group: populated on first offload. + # dtype → dict with keys: + # "indices" : list[int] managed-list indices + # "offsets" : list[tuple[int,int]] (start, numel) in flat buf + # "total" : int total numel + # "cpu_flat" : Tensor pinned CPU buffer + self._groups: dict[torch.dtype, dict] = {} + + self._offload_stream: torch.cuda.Stream | None = None + self._device: torch.device | None = None + self._initialized: bool = False + self._logged: bool = False + + # ------------------------------------------------------------------ + @staticmethod + def _local(t: torch.Tensor) -> torch.Tensor: + """Unwrap DTensor to its local CUDA tensor.""" + return t._local_tensor if isinstance(t, DTensor) else t + + def _ensure_stream(self): + if self._offload_stream is None: + self._offload_stream = torch.cuda.Stream(device=self._device) + + # ------------------------------------------------------------------ + def track(self, tensor: torch.Tensor): + """Register a GPU tensor for CPU offloading. Idempotent.""" + tid = id(tensor) + if tid in self._storage_nbytes: + return + local = self._local(tensor) + if self._device is None: + self._device = local.device + self._storage_nbytes[tid] = local.untyped_storage().size() + self._managed.append(tensor) + + # ------------------------------------------------------------------ + def _init_buffers(self): + """Build per-dtype flat buffers on first offload.""" + # Group managed tensors by dtype. 
+ dtype_map: dict[torch.dtype, list[tuple[int, int]]] = defaultdict(list) + for idx, t in enumerate(self._managed): + local = self._local(t) + dtype_map[local.dtype].append((idx, local.numel())) + + total_cpu_bytes = 0 + for dtype, entries in dtype_map.items(): + offsets: list[tuple[int, int]] = [] + indices: list[int] = [] + off = 0 + for idx, n in entries: + indices.append(idx) + offsets.append((off, n)) + off += n + cpu_flat = torch.empty(off, + dtype=dtype, + device="cpu", + pin_memory=True) + self._groups[dtype] = { + "indices": indices, + "offsets": offsets, + "total": off, + "cpu_flat": cpu_flat, + } + total_cpu_bytes += off * cpu_flat.element_size() + + self._initialized = True + logger.info( + "[CPUOffload] Pool initialized: %d tensors, %d dtype group(s), " + "%.2f MB pinned CPU memory", + len(self._managed), + len(self._groups), + total_cpu_bytes / (1024**2), + ) + + # ------------------------------------------------------------------ + def offload(self): + """Per-tensor async D2H into CPU flat buffer, then free GPU storage.""" + if not self._managed: + return + if not self._initialized: + self._init_buffers() + self._ensure_stream() + + # Offload stream waits for compute to finish. + compute_event = torch.cuda.current_stream(self._device).record_event() + self._offload_stream.wait_event(compute_event) + + offloaded_bytes = 0 + + # Per-tensor D2H copies directly into CPU flat buffer slices. + # No GPU staging buffer → no temporary GPU memory spike. + with torch.cuda.stream(self._offload_stream): + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + cpu_flat[off:off + n].copy_(local.reshape(-1), + non_blocking=True) + + offloaded_bytes += grp["total"] * cpu_flat.element_size() + + # Wait for all D2H copies to land, then free GPU storage. 
+ self._offload_stream.synchronize() + for t in self._managed: + self._local(t).untyped_storage().resize_(0) + + if not self._logged: + logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)", + offloaded_bytes / (1024**2)) + + # ------------------------------------------------------------------ + def reload(self): + """Per-tensor H2D from CPU flat buffer on the default stream. + + Runs on the current (default) CUDA stream to avoid stream + interaction issues with the parallel Muon pipeline. Since + pinned CPU memory is the source, the copies overlap with + GPU idle time between steps. + """ + if not self._managed or not self._initialized: + return + + reloaded_bytes = 0 + + # Re-allocate all GPU storages first. + for t in self._managed: + local = self._local(t) + local.untyped_storage().resize_(self._storage_nbytes[id(t)]) + + # Per-tensor H2D copies from CPU flat buffer slices. + # non_blocking=True with pinned source allows DMA overlap. + for dtype, grp in self._groups.items(): + indices = grp["indices"] + offsets = grp["offsets"] + cpu_flat = grp["cpu_flat"] + + for i, mgd_idx in enumerate(indices): + local = self._local(self._managed[mgd_idx]) + off, n = offsets[i] + local.reshape(-1).copy_(cpu_flat[off:off + n], + non_blocking=True) + + reloaded_bytes += grp["total"] * cpu_flat.element_size() + + if not self._logged: + logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)", + reloaded_bytes / (1024**2)) + self._logged = True diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index a8d0751a7ef0c730eb1f0e63ef5e8f6a93a289d3..0115ae037bcf850a4547fe6e992e1e10a89905f7 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -12,6 +12,7 @@ from .adamw import step_adamw from .async_utils import run_pipeline from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, get_default_muon_param_groups, is_expert_param, update_p) +from .cpu_offload import CPUOffloadPool from .distributed.utils import (_is_shard, 
def _register_states_for_offload(self):
    """Hand every existing optimizer state tensor to the CPU offload pool.

    Walks all parameter groups and, for each parameter that already has an
    entry in ``self.state``, tracks ``momentum_buffer`` (groups with
    ``use_muon`` set) or ``moment1`` / ``moment2`` (all other groups) with
    ``self._cpu_offload_pool``.  Logs the number of parameter states that
    were registered.
    """
    offload_pool = self._cpu_offload_pool
    registered = 0
    for group in self.param_groups:
        is_muon_group = group.get("use_muon", False)
        for param in group["params"]:
            # Parameters without state yet have nothing to offload.
            if param not in self.state:
                continue
            param_state = self.state[param]
            if is_muon_group:
                if "momentum_buffer" in param_state:
                    offload_pool.track(param_state["momentum_buffer"])
                    registered += 1
                continue
            for moment_key in ("moment1", "moment2"):
                if moment_key in param_state:
                    offload_pool.track(param_state[moment_key])
            registered += 1
    logger.info("[CPUOffload] Registered %d param states for offload",
                registered)
def state_dict(self) -> dict:
    """Return the optimizer state dict, reloading offloaded states first.

    When ``cpu_offload`` is enabled and offloading has been initialised,
    the state tensors have had their storage freed (``resize_(0)``)
    between steps.  They are reloaded, the parent ``state_dict()`` is
    taken, every tensor in its ``"state"`` entries is cloned, and the pool
    is offloaded again.  The clones keep the returned dict valid after the
    re-offload frees the originals' storage.
    """
    offload_active = self.cpu_offload and self._offload_initialized
    if offload_active:
        self._cpu_offload_pool.reload()
        torch.cuda.current_stream().synchronize()

    sd = super().state_dict()

    if offload_active:
        for k in sd["state"]:
            sd["state"][k] = {
                sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
                for sk, sv in sd["state"][k].items()
            }
        self._cpu_offload_pool.offload()
    return sd


def load_state_dict(self, state_dict: dict) -> None:
    """Load an optimizer state dict, then offload the loaded states.

    With ``cpu_offload`` enabled, a new offload pool is created after the
    parent ``load_state_dict()`` (the state tensors may be new objects),
    the loaded state tensors are registered with it and offloaded.

    If the loaded state contains no tensors to track, offloading is left
    uninitialised so that the first ``step()`` that creates states also
    registers them.  Marking it initialised with an empty pool would make
    ``step()`` skip registration and offloading would never start.
    """
    # If states were offloaded, reload first so storage sizes are
    # correct for super().load_state_dict() to overwrite.
    if self.cpu_offload and self._offload_initialized:
        self._cpu_offload_pool.reload()
        torch.cuda.current_stream().synchronize()

    super().load_state_dict(state_dict)

    if not self.cpu_offload:
        return

    self._cpu_offload_pool = CPUOffloadPool()
    self._offload_initialized = False
    self._register_states_for_offload()
    if self._cpu_offload_pool._managed:
        self._offload_initialized = True
        self._cpu_offload_pool.offload()