| """CPU offloading tests for optimizer states. |
| |
| Run with: |
| torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_offload.py |
| |
| Tests: |
| 1. Correctness: turn_on_cpu_offload() produces identical results to no offload |
| 2. Memory: GPU optimizer state storage is freed after offload |
| 3. AdamW: moment1/moment2 offloading works correctly |
| """ |
|
|
| import copy |
| import logging |
|
|
| import pytest |
| import torch |
| import torch.distributed as dist |
| from torch.distributed.tensor import DTensor, Shard, distribute_tensor |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") |
|
|
|
|
| def _setup(): |
| dist.init_process_group(backend="nccl") |
| rank = dist.get_rank() |
| torch.cuda.set_device(rank % torch.cuda.device_count()) |
| return rank, dist.get_world_size() |
|
|
|
|
| def _make_mesh(world_size): |
| return dist.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) |
|
|
|
|
| def test_correctness(rank, world_size): |
| """Verify that turn_on_cpu_offload() produces identical parameters as no offload.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 64, 128 |
| num_params = 4 |
| num_steps = 3 |
|
|
| |
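    # Shared full params/grads so both optimizers see identical inputs.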
| full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] |
| full_grads = [ |
| [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] |
| for _ in range(num_steps) |
| ] |
|
|
| def make_optimizer(cpu_offload): |
| params, names = [], [] |
| for i, fp in enumerate(full_params): |
| dt = distribute_tensor(fp.clone(), mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| if cpu_offload: |
| optim.turn_on_cpu_offload() |
| return optim, params |
|
|
| optim_ref, params_ref = make_optimizer(False) |
| optim_off, params_off = make_optimizer(True) |
|
|
| for step_idx in range(num_steps): |
| for i in range(num_params): |
| g = full_grads[step_idx][i] |
| params_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| params_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
|
|
| optim_ref.step() |
| optim_off.step() |
|
|
| for i in range(num_params): |
| ref_full = params_ref[i].data.full_tensor() |
| off_full = params_off[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) |
|
|
| if rank == 0: |
| logger.info("Step %d: correctness OK", step_idx) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_correctness") |
|
|
|
|
| def test_memory(rank, world_size): |
| """Verify that GPU storage is freed after offload.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 512, 1024 |
| num_params = 8 |
|
|
| params, names = [], [] |
| for i in range(num_params): |
| full = torch.randn(dim0, dim1, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
|
|
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| optim.turn_on_cpu_offload() |
|
|
| optim.step() |
| torch.cuda.synchronize() |
|
|
| |
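    # After step(), each momentum buffer should have been offloaded: the
    # GPU-side storage must be resized to zero bytes.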
| for p in params: |
| state = optim.state[p] |
| if "momentum_buffer" not in state: |
| continue |
| buf = state["momentum_buffer"] |
| local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf |
| assert local_buf.untyped_storage().size() == 0, ( |
| f"Expected freed GPU storage after offload, got " |
| f"{local_buf.untyped_storage().size()} bytes" |
| ) |
|
|
| |
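    # The offload pool should track the buffers and hold them in pinned CPU memory.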
| pool = optim._cpu_offload_pool |
| assert len(pool._managed) > 0, "No tensors tracked by CPU offload pool" |
| for grp in pool._groups.values(): |
| assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned memory" |
|
|
| |
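    # A second step must work from offloaded state (reload, update, re-offload).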
| for p in params: |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| optim.step() |
| torch.cuda.synchronize() |
|
|
| |
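    # State should be offloaded again once the second step completes.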
| for p in params: |
| state = optim.state[p] |
| if "momentum_buffer" not in state: |
| continue |
| buf = state["momentum_buffer"] |
| local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf |
| assert local_buf.untyped_storage().size() == 0 |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_memory") |
|
|
|
|
| def test_adamw_offload(rank, world_size): |
| """Verify AdamW moment1/moment2 are offloaded correctly.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| num_steps = 3 |
|
|
| |
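    # 2D weights take the Muon path; 1D biases take the AdamW path.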
| muon_params, muon_names = [], [] |
| adamw_params, adamw_names = [], [] |
|
|
| for i in range(4): |
| full = torch.randn(64, 128, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| muon_params.append(p) |
| muon_names.append(f"layer.{i}.weight") |
|
|
| for i in range(3): |
| full = torch.randn(128, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| adamw_params.append(p) |
| adamw_names.append(f"layer.{i}.bias") |
|
|
| |
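    # Pre-generate gradients so both optimizers consume identical sequences.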
| muon_grads = [ |
| [torch.randn(64, 128, device="cuda") for _ in range(4)] |
| for _ in range(num_steps) |
| ] |
| adamw_grads = [ |
| [torch.randn(128, device="cuda") for _ in range(3)] for _ in range(num_steps) |
| ] |
|
|
| def make_optimizer(cpu_offload): |
| mp = [ |
| torch.nn.Parameter( |
| distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)]) |
| ) |
| for p in muon_params |
| ] |
| ap = [ |
| torch.nn.Parameter( |
| distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)]) |
| ) |
| for p in adamw_params |
| ] |
| param_groups = [ |
| { |
| "params": mp, |
| "names": list(muon_names), |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| "adamw_betas": (0.9, 0.95), |
| "adamw_eps": 1e-8, |
| }, |
| { |
| "params": ap, |
| "use_muon": False, |
| "lr": 1e-3, |
| "weight_decay": 0.01, |
| "adamw_betas": (0.9, 0.95), |
| "adamw_eps": 1e-8, |
| }, |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| if cpu_offload: |
| optim.turn_on_cpu_offload() |
| return optim, mp, ap |
|
|
| optim_ref, mp_ref, ap_ref = make_optimizer(False) |
| optim_off, mp_off, ap_off = make_optimizer(True) |
|
|
| for step_idx in range(num_steps): |
| for i in range(4): |
| g = muon_grads[step_idx][i] |
| mp_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| mp_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| for i in range(3): |
| g = adamw_grads[step_idx][i] |
| ap_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| ap_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
|
|
| optim_ref.step() |
| optim_off.step() |
|
|
| |
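        # Muon-path parameters must match bitwise at every step.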
| for i in range(4): |
| ref_full = mp_ref[i].data.full_tensor() |
| off_full = mp_off[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) |
|
|
| |
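        # AdamW-path parameters must match bitwise at every step.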
| for i in range(3): |
| ref_full = ap_ref[i].data.full_tensor() |
| off_full = ap_off[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0) |
|
|
| if rank == 0: |
| logger.info("Step %d: AdamW offload correctness OK", step_idx) |
|
|
| |
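    # After the run, AdamW moment1/moment2 GPU storage should be freed.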
| for p in ap_off: |
| state = optim_off.state[p] |
| for key in ("moment1", "moment2"): |
| if key not in state: |
| continue |
| t = state[key] |
| local_t = t._local_tensor if isinstance(t, DTensor) else t |
| assert local_t.untyped_storage().size() == 0, ( |
| f"AdamW {key} storage not freed after offload" |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_adamw_offload") |
|
|
|
|
| def test_memory_savings(rank, world_size): |
| """Measure actual GPU memory difference with and without offload.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
|
|
| mesh = _make_mesh(world_size) |
| dim0, dim1 = 1024, 2048 |
| num_params = 8 |
|
|
| def run_step(cpu_offload): |
| torch.cuda.empty_cache() |
| torch.cuda.reset_peak_memory_stats() |
| torch.manual_seed(42) |
|
|
| params, names = [], [] |
| for i in range(num_params): |
| full = torch.randn(dim0, dim1, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
|
|
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| if cpu_offload: |
| optim.turn_on_cpu_offload() |
| optim.step() |
| torch.cuda.synchronize() |
|
|
| mem = torch.cuda.memory_allocated() |
| |
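        # Tear everything down so the next run starts from a clean allocator.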
| del optim, params, param_groups |
| torch.cuda.empty_cache() |
| return mem |
|
|
| mem_no_offload = run_step(False) |
| mem_with_offload = run_step(True) |
|
|
| if rank == 0: |
| logger.info("Memory without offload: %.2f MB", mem_no_offload / 1024**2) |
| logger.info("Memory with offload: %.2f MB", mem_with_offload / 1024**2) |
| saved = mem_no_offload - mem_with_offload |
| logger.info("Memory saved: %.2f MB", saved / 1024**2) |
|
|
| assert mem_with_offload < mem_no_offload, ( |
| f"Expected memory reduction with CPU offload. " |
| f"Without: {mem_no_offload / 1024**2:.2f} MB, " |
| f"With: {mem_with_offload / 1024**2:.2f} MB" |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_memory_savings") |
|
|
|
|
| def test_toggle_correctness(rank, world_size): |
| """Verify toggling offload on/off between steps produces identical results.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 64, 128 |
| num_params = 4 |
| num_steps = 6 |
|
|
| full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] |
| full_grads = [ |
| [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)] |
| for _ in range(num_steps) |
| ] |
|
|
| def make_optimizer(): |
| params, names = [], [] |
| for i, fp in enumerate(full_params): |
| dt = distribute_tensor(fp.clone(), mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| return optim, params |
|
|
| |
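    # Reference optimizer: offload is never enabled.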
| optim_ref, params_ref = make_optimizer() |
|
|
| |
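    # Toggled optimizer: offload is flipped on and off during training.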
| optim_toggle, params_toggle = make_optimizer() |
|
|
| for step_idx in range(num_steps): |
| |
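        # Offload is on for steps 0-1, off for 2-3, and on again for 4-5.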
| want_on = (step_idx // 2) % 2 == 0 |
| if want_on and not optim_toggle.cpu_offload: |
| optim_toggle.turn_on_cpu_offload() |
| elif not want_on and optim_toggle.cpu_offload: |
| optim_toggle.turn_off_cpu_offload() |
|
|
| for i in range(num_params): |
| g = full_grads[step_idx][i] |
| params_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| params_toggle[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
|
|
| optim_ref.step() |
| optim_toggle.step() |
|
|
| for i in range(num_params): |
| ref_full = params_ref[i].data.full_tensor() |
| tog_full = params_toggle[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, tog_full, atol=0, rtol=0) |
|
|
| if rank == 0: |
| logger.info( |
| "Step %d (offload=%s): toggle correctness OK", |
| step_idx, |
| optim_toggle.cpu_offload, |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_toggle_correctness") |
|
|
|
|
| def test_leak(rank, world_size): |
| """Run many iterations and verify no CPU/GPU memory leak.""" |
| import os |
|
|
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 512, 1024 |
| num_params = 8 |
| num_steps = 50 |
|
|
| params, names = [], [] |
| for i in range(num_params): |
| full = torch.randn(dim0, dim1, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
|
|
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| optim.turn_on_cpu_offload() |
|
|
| def get_cpu_rss_mb(): |
| """Get current process RSS in MB from /proc/self/statm.""" |
| with open("/proc/self/statm") as f: |
| pages = int(f.read().split()[1]) |
| return pages * os.sysconf("SC_PAGE_SIZE") / (1024**2) |
|
|
| gpu_after_warmup = None |
| cpu_after_warmup = None |
|
|
| for step_idx in range(num_steps): |
| for p in params: |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
|
|
| optim.step() |
| torch.cuda.synchronize() |
|
|
| gpu_mem = torch.cuda.memory_allocated() |
| cpu_mem = get_cpu_rss_mb() |
|
|
| |
| |
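        # Record the baseline after a few steps so one-time allocations
        # (pool setup, allocator caches) are excluded from the leak check.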
| if step_idx == 2: |
| gpu_after_warmup = gpu_mem |
| cpu_after_warmup = cpu_mem |
|
|
| if rank == 0 and step_idx % 10 == 0: |
| logger.info( |
| "Step %d: GPU alloc=%.2f MB, CPU RSS=%.2f MB", |
| step_idx, |
| gpu_mem / (1024**2), |
| cpu_mem, |
| ) |
|
|
| |
| torch.cuda.synchronize() |
| gpu_final = torch.cuda.memory_allocated() |
| cpu_final = get_cpu_rss_mb() |
|
|
| if rank == 0: |
| logger.info( |
| "After %d steps: GPU alloc=%.2f MB, CPU RSS=%.2f MB", |
| num_steps, |
| gpu_final / (1024**2), |
| cpu_final, |
| ) |
| logger.info( |
| "Warmup baseline: GPU alloc=%.2f MB, CPU RSS=%.2f MB", |
| gpu_after_warmup / (1024**2), |
| cpu_after_warmup, |
| ) |
|
|
| |
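    # GPU allocation must not grow past the post-warmup baseline.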
| assert gpu_final <= gpu_after_warmup, ( |
| f"GPU memory leak detected! Warmup: {gpu_after_warmup / 1024**2:.2f} MB, " |
| f"Final: {gpu_final / 1024**2:.2f} MB" |
| ) |
|
|
| |
| |
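    # Allow modest RSS jitter from the CPU allocator; growth beyond 50 MB
    # over the remaining steps indicates a leak.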
| cpu_growth = cpu_final - cpu_after_warmup |
| assert cpu_growth < 50, ( |
| f"CPU memory leak detected! Growth: {cpu_growth:.2f} MB over " |
| f"{num_steps - 2} steps (warmup={cpu_after_warmup:.2f} MB, " |
| f"final={cpu_final:.2f} MB)" |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)", cpu_growth) |
|
|
|
|
| def test_state_dict_save_load(rank, world_size): |
| """Verify state_dict() works after offload and load_state_dict() resumes correctly. |
| |
| Uses torch.distributed.checkpoint (DCP) for serialization, matching |
| the actual LLM training checkpoint flow. DCP handles DTensors natively |
| so the roundtrip is bitwise exact. |
| """ |
| import shutil |
| import tempfile |
|
|
| import torch.distributed.checkpoint as dcp |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 64, 128 |
| num_muon = 4 |
| num_adamw = 3 |
| num_steps = 3 |
|
|
| |
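    # Fixed initial params plus enough grads for both phases (train + resume).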
| muon_init = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)] |
| adamw_init = [torch.randn(dim1, device="cuda") for _ in range(num_adamw)] |
| all_grads_muon = [ |
| [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)] |
| for _ in range(num_steps * 2) |
| ] |
| all_grads_adamw = [ |
| [torch.randn(dim1, device="cuda") for _ in range(num_adamw)] |
| for _ in range(num_steps * 2) |
| ] |
|
|
| def make_optimizer(cpu_offload): |
| mp = [ |
| torch.nn.Parameter( |
| distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)]) |
| ) |
| for i in range(num_muon) |
| ] |
| ap = [ |
| torch.nn.Parameter( |
| distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)]) |
| ) |
| for i in range(num_adamw) |
| ] |
| param_groups = [ |
| { |
| "params": mp, |
| "names": [f"layer.{i}.weight" for i in range(num_muon)], |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| "adamw_betas": (0.9, 0.95), |
| "adamw_eps": 1e-8, |
| }, |
| { |
| "params": ap, |
| "use_muon": False, |
| "lr": 1e-3, |
| "weight_decay": 0.01, |
| "adamw_betas": (0.9, 0.95), |
| "adamw_eps": 1e-8, |
| }, |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| if cpu_offload: |
| optim.turn_on_cpu_offload() |
| return optim, mp, ap |
|
|
| |
| |
| |
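    # Phase 1: train with offload enabled, then checkpoint.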
| optim_off, mp_off, ap_off = make_optimizer(True) |
|
|
| for step_idx in range(num_steps): |
| for i in range(num_muon): |
| mp_off[i].grad = distribute_tensor( |
| all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)] |
| ) |
| for i in range(num_adamw): |
| ap_off[i].grad = distribute_tensor( |
| all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)] |
| ) |
| optim_off.step() |
|
|
| with pytest.raises( |
| RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save" |
| ): |
| optim_off.state_dict() |
|
|
| optim_off.turn_off_cpu_offload() |
| sd_off = optim_off.state_dict() |
|
|
| |
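    # Every floating-point state tensor in the checkpoint must have real storage.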
| for param_states in sd_off["state"].values(): |
| for key, val in param_states.items(): |
| if isinstance(val, torch.Tensor) and val.is_floating_point(): |
| assert val.untyped_storage().size() > 0, ( |
| f"state_dict() returned empty storage for key '{key}' — " |
| f"offload reload is broken" |
| ) |
|
|
| if rank == 0: |
| logger.info("state_dict() contains valid (non-empty) tensors") |
|
|
| |
| |
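    # Flatten the optimizer state into DCP-style keys ("state.<idx>.<key>").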
| flat_state = {} |
| for param_idx, param_state in sd_off["state"].items(): |
| for key, val in param_state.items(): |
| if isinstance(val, torch.Tensor): |
| flat_state[f"state.{param_idx}.{key}"] = val |
|
|
| |
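    # Rank 0 creates the checkpoint directory and broadcasts it to all ranks.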
| if rank == 0: |
| ckpt_dir = tempfile.mkdtemp(prefix="cpu_offload_test_") |
| else: |
| ckpt_dir = "" |
| ckpt_dir_list = [ckpt_dir] |
| dist.broadcast_object_list(ckpt_dir_list, src=0) |
| ckpt_dir = ckpt_dir_list[0] |
| try: |
| dcp.save(flat_state, checkpoint_id=ckpt_dir) |
| dist.barrier() |
|
|
| if rank == 0: |
| logger.info("DCP save completed to %s", ckpt_dir) |
|
|
| |
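        # Reference: reload directly from the in-memory state_dict.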
| optim_ref, mp_ref, ap_ref = make_optimizer(True) |
| for i in range(num_muon): |
| mp_ref[i].data = mp_off[i].data.clone() |
| for i in range(num_adamw): |
| ap_ref[i].data = ap_off[i].data.clone() |
| with pytest.raises( |
| RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" |
| ): |
| optim_ref.load_state_dict(copy.deepcopy(sd_off)) |
| optim_ref.turn_off_cpu_offload() |
| optim_ref.load_state_dict(copy.deepcopy(sd_off)) |
| optim_ref.turn_on_cpu_offload() |
|
|
| |
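        # Resumed: reload through the on-disk DCP roundtrip.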
| optim_resumed, mp_resumed, ap_resumed = make_optimizer(True) |
| for i in range(num_muon): |
| mp_resumed[i].data = mp_off[i].data.clone() |
| for i in range(num_adamw): |
| ap_resumed[i].data = ap_off[i].data.clone() |
|
|
| flat_target = {k: torch.zeros_like(v) for k, v in flat_state.items()} |
| dcp.load(flat_target, checkpoint_id=ckpt_dir) |
| dist.barrier() |
|
|
| sd_loaded = copy.deepcopy(sd_off) |
| for param_idx, param_state in sd_loaded["state"].items(): |
| for key in list(param_state.keys()): |
| flat_key = f"state.{param_idx}.{key}" |
| if flat_key in flat_target: |
| param_state[key] = flat_target[flat_key] |
| with pytest.raises( |
| RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" |
| ): |
| optim_resumed.load_state_dict(copy.deepcopy(sd_loaded)) |
| optim_resumed.turn_off_cpu_offload() |
| optim_resumed.load_state_dict(sd_loaded) |
| optim_resumed.turn_on_cpu_offload() |
|
|
| if rank == 0: |
| logger.info("Both optimizers loaded, starting comparison steps") |
|
|
| finally: |
| dist.barrier() |
| if rank == 0: |
| shutil.rmtree(ckpt_dir, ignore_errors=True) |
|
|
| |
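    # Phase 2: keep training; reference and resumed must stay bitwise identical.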
| for step_idx in range(num_steps, num_steps * 2): |
| for i in range(num_muon): |
| g = all_grads_muon[step_idx][i] |
| mp_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| mp_resumed[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| for i in range(num_adamw): |
| g = all_grads_adamw[step_idx][i] |
| ap_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| ap_resumed[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)]) |
| optim_ref.step() |
| optim_resumed.step() |
|
|
| |
| for i in range(num_muon): |
| ref_full = mp_ref[i].data.full_tensor() |
| res_full = mp_resumed[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, res_full, atol=0, rtol=0) |
|
|
| for i in range(num_adamw): |
| ref_full = ap_ref[i].data.full_tensor() |
| res_full = ap_resumed[i].data.full_tensor() |
| torch.testing.assert_close(ref_full, res_full, atol=0, rtol=0) |
|
|
| |
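    # The resumed optimizer should offload its state again after stepping.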
| for p in mp_resumed: |
| state = optim_resumed.state[p] |
| if "momentum_buffer" in state: |
| buf = state["momentum_buffer"] |
| local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf |
| assert local_buf.untyped_storage().size() == 0, ( |
| "Resumed optimizer should have offloaded state after step()" |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_state_dict_save_load") |
|
|
|
|
| def test_checkpoint_memory(rank, world_size): |
| """Verify checkpoint APIs require offload to be disabled explicitly.""" |
| from optimizer.muon import Muon |
| from optimizer.newton_schulz import set_ns_compile |
|
|
| set_ns_compile(False) |
| torch.manual_seed(42) |
|
|
| mesh = _make_mesh(world_size) |
|
|
| dim0, dim1 = 512, 1024 |
| num_params = 8 |
|
|
| params, names = [], [] |
| for i in range(num_params): |
| full = torch.randn(dim0, dim1, device="cuda") |
| dt = distribute_tensor(full, mesh, [Shard(0)]) |
| p = torch.nn.Parameter(dt) |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| params.append(p) |
| names.append(f"layer.{i}.weight") |
|
|
| param_groups = [ |
| { |
| "params": params, |
| "names": names, |
| "use_muon": True, |
| "lr": 0.02, |
| "weight_decay": 0.01, |
| "momentum": 0.95, |
| "nesterov": True, |
| "ns_steps": 5, |
| "none_grad": False, |
| } |
| ] |
| optim = Muon(params=param_groups, chunk_size=2, warmup_step=1) |
| optim.turn_on_cpu_offload() |
|
|
| |
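    # One step populates the optimizer state and immediately offloads it.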
| optim.step() |
| torch.cuda.synchronize() |
|
|
| mem_after_step = torch.cuda.memory_allocated() |
|
|
| |
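    # Sum the expected size of the offloaded momentum buffers for logging.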
| state_bytes = 0 |
| for p in params: |
| state = optim.state[p] |
| if "momentum_buffer" in state: |
| buf = state["momentum_buffer"] |
            # The pool records how many bytes it manages for each buffer; use
            # that as the expected size of the offloaded state in the log below.
            state_bytes += optim._cpu_offload_pool._storage_nbytes[id(buf)]
|
|
| if rank == 0: |
| logger.info( |
| "After step (offloaded): GPU alloc=%.2f MB, expected state size=%.2f MB", |
| mem_after_step / 1024**2, |
| state_bytes / 1024**2, |
| ) |
|
|
| with pytest.raises( |
| RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save" |
| ): |
| optim.state_dict() |
|
|
| optim.turn_off_cpu_offload() |
| torch.cuda.synchronize() |
| mem_after_turn_off = torch.cuda.memory_allocated() |
| sd_for_load = copy.deepcopy(optim.state_dict()) |
|
|
| if rank == 0: |
| logger.info( |
| "After turn_off_cpu_offload: GPU alloc=%.2f MB", |
| mem_after_turn_off / 1024**2, |
| ) |
|
|
| assert mem_after_turn_off > mem_after_step, ( |
| f"turn_off_cpu_offload() should reload states to GPU. " |
| f"After offload: {mem_after_step / 1024**2:.2f} MB, " |
| f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB" |
| ) |
|
|
| optim.turn_on_cpu_offload() |
| torch.cuda.synchronize() |
| mem_after_turn_on = torch.cuda.memory_allocated() |
|
|
| if rank == 0: |
| logger.info( |
| "After turn_on_cpu_offload: GPU alloc=%.2f MB", mem_after_turn_on / 1024**2 |
| ) |
|
|
| assert mem_after_turn_on <= mem_after_step + 4 * 1024 * 1024, ( |
| f"turn_on_cpu_offload() should return memory to offloaded level. " |
| f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " |
| f"got {mem_after_turn_on / 1024**2:.2f} MB" |
| ) |
|
|
| for p in params: |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| optim.step() |
| torch.cuda.synchronize() |
|
|
| mem_after_next_step = torch.cuda.memory_allocated() |
|
|
| if rank == 0: |
| logger.info( |
| "After next step (re-offloaded): GPU alloc=%.2f MB", |
| mem_after_next_step / 1024**2, |
| ) |
|
|
| |
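    # step() re-offloads, so allocation should drop back to the offloaded level.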
| assert mem_after_next_step <= mem_after_step + 4 * 1024 * 1024, ( |
| f"Memory should return to offloaded level after step(). " |
| f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " |
| f"got {mem_after_next_step / 1024**2:.2f} MB" |
| ) |
|
|
| with pytest.raises( |
| RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load" |
| ): |
| optim.load_state_dict(copy.deepcopy(sd_for_load)) |
|
|
| optim.turn_off_cpu_offload() |
| optim.load_state_dict(sd_for_load) |
| torch.cuda.synchronize() |
|
|
| mem_after_load = torch.cuda.memory_allocated() |
|
|
| if rank == 0: |
| logger.info( |
| "After load_state_dict with offload disabled: GPU alloc=%.2f MB", |
| mem_after_load / 1024**2, |
| ) |
|
|
| assert mem_after_load >= mem_after_turn_off, ( |
| "Loaded optimizer state should stay on GPU while offload is disabled" |
| ) |
|
|
| optim.turn_on_cpu_offload() |
| torch.cuda.synchronize() |
|
|
| pool = optim._cpu_offload_pool |
| assert pool._initialized, ( |
| "Offload pool should be initialized after re-enabling offload" |
| ) |
| for grp in pool._groups.values(): |
| assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned" |
|
|
| |
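    # The optimizer must keep working normally after the checkpoint roundtrip.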
| for p in params: |
| p.grad = distribute_tensor( |
| torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)] |
| ) |
| optim.step() |
| torch.cuda.synchronize() |
|
|
| mem_final = torch.cuda.memory_allocated() |
| assert mem_final <= mem_after_step + 4 * 1024 * 1024, ( |
| f"Final memory should be at offloaded level. " |
| f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), " |
| f"got {mem_final / 1024**2:.2f} MB" |
| ) |
|
|
| set_ns_compile(True) |
| if rank == 0: |
| logger.info("PASSED: test_checkpoint_memory") |
|
|
|
|
| def main(): |
| rank, world_size = _setup() |
|
|
| try: |
| test_correctness(rank, world_size) |
| test_memory(rank, world_size) |
| test_adamw_offload(rank, world_size) |
| test_memory_savings(rank, world_size) |
| test_toggle_correctness(rank, world_size) |
| test_leak(rank, world_size) |
| test_state_dict_save_load(rank, world_size) |
| test_checkpoint_memory(rank, world_size) |
|
|
| if rank == 0: |
| logger.info("=" * 50) |
| logger.info("ALL CPU OFFLOAD TESTS PASSED") |
| logger.info("=" * 50) |
| finally: |
| dist.destroy_process_group() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|