Kernels
optimizer / test /test_cpu_offload.py
dongseokmotif's picture
feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build] (#28)
e8e2c81 unverified
"""CPU offloading tests for optimizer states.
Run with:
torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_offload.py
Tests:
1. Correctness: turn_on_cpu_offload() produces identical results to no offload
2. Memory: GPU optimizer state storage is freed after offload
3. AdamW: moment1/moment2 offloading works correctly
"""
import copy
import logging
import pytest
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard, distribute_tensor
# Module-level logger used by every test below for progress messages.
logger = logging.getLogger(__name__)
# Configure root logging at import time: INFO level, "[LEVEL] message" format.
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
def _setup():
    """Initialise the NCCL process group and select this process's CUDA device.

    Returns:
        A ``(rank, world_size)`` tuple for the calling process.
    """
    dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    # Wrap the global rank onto the CUDA devices visible to this process.
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    world_size = dist.get_world_size()
    return global_rank, world_size
def _make_mesh(world_size):
    """Build a 1-D CUDA device mesh of ``world_size`` ranks named ``"dp"``."""
    mesh_shape = (world_size,)
    dim_names = ("dp",)
    return dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=dim_names)
def test_correctness(rank, world_size):
    """Check that stepping with CPU offload enabled matches stepping without it.

    Two Muon optimizers are built from the same initial weights; only one has
    ``turn_on_cpu_offload()`` applied. Both are fed identical gradients and
    the gathered parameters must be exactly equal after every step.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    rows, cols = 64, 128
    n_params = 4
    n_steps = 3

    def shard_copy(tensor):
        # Shard a clone of ``tensor`` along dim 0 over the mesh.
        return distribute_tensor(tensor.clone(), mesh, [Shard(0)])

    # Every rank uses the same seed, so the generated tensors agree everywhere.
    init_weights = []
    for _ in range(n_params):
        init_weights.append(torch.randn(rows, cols, device="cuda"))
    grads_by_step = []
    for _ in range(n_steps):
        step_grads = []
        for _ in range(n_params):
            step_grads.append(torch.randn(rows, cols, device="cuda"))
        grads_by_step.append(step_grads)

    def make_optimizer(cpu_offload):
        # Build fresh sharded parameters and a Muon optimizer over them.
        weights = [torch.nn.Parameter(shard_copy(w)) for w in init_weights]
        labels = [f"layer.{idx}.weight" for idx in range(len(weights))]
        group = {
            "params": weights,
            "names": labels,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
        optimizer = Muon(params=[group], chunk_size=2, warmup_step=1)
        if cpu_offload:
            optimizer.turn_on_cpu_offload()
        return optimizer, weights

    ref_optim, ref_weights = make_optimizer(False)
    off_optim, off_weights = make_optimizer(True)

    for step_idx, step_grads in enumerate(grads_by_step):
        for w_ref, w_off, grad in zip(ref_weights, off_weights, step_grads):
            w_ref.grad = shard_copy(grad)
            w_off.grad = shard_copy(grad)

        ref_optim.step()
        off_optim.step()

        # Zero tolerances: offloading must not perturb a single bit.
        for w_ref, w_off in zip(ref_weights, off_weights):
            expected = w_ref.data.full_tensor()
            actual = w_off.data.full_tensor()
            torch.testing.assert_close(expected, actual, atol=0, rtol=0)

        if rank == 0:
            logger.info("Step %d: correctness OK", step_idx)

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_correctness")
def test_memory(rank, world_size):
    """Check that momentum buffers hold no GPU storage once offloaded.

    After each of two optimizer steps with offload enabled, every
    ``momentum_buffer`` must report zero-sized untyped storage. The CPU
    offload pool must also track tensors and use pinned host buffers.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    rows, cols = 512, 1024
    n_params = 8

    def random_sharded():
        # Fresh random matrix, sharded along dim 0 over the mesh.
        return distribute_tensor(
            torch.randn(rows, cols, device="cuda"), mesh, [Shard(0)]
        )

    weights, labels = [], []
    for idx in range(n_params):
        weight = torch.nn.Parameter(random_sharded())
        weight.grad = random_sharded()
        weights.append(weight)
        labels.append(f"layer.{idx}.weight")

    group = {
        "params": weights,
        "names": labels,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }
    optim = Muon(params=[group], chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()

    def local_momentum_buffers():
        # Yield the rank-local tensor behind each existing momentum buffer.
        for weight in weights:
            state = optim.state[weight]
            if "momentum_buffer" in state:
                buf = state["momentum_buffer"]
                yield buf._local_tensor if isinstance(buf, DTensor) else buf

    optim.step()
    torch.cuda.synchronize()

    # First step done: no momentum buffer may still own GPU storage.
    for local_buf in local_momentum_buffers():
        assert local_buf.untyped_storage().size() == 0, (
            f"Expected freed GPU storage after offload, got "
            f"{local_buf.untyped_storage().size()} bytes"
        )

    # The pool must be tracking tensors and backing them with pinned memory.
    pool = optim._cpu_offload_pool
    assert len(pool._managed) > 0, "No tensors tracked by CPU offload pool"
    for grp in pool._groups.values():
        assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned memory"

    # A second step exercises the reload -> compute -> offload cycle.
    for weight in weights:
        weight.grad = random_sharded()
    optim.step()
    torch.cuda.synchronize()

    # The buffers must be released again afterwards.
    for local_buf in local_momentum_buffers():
        assert local_buf.untyped_storage().size() == 0

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_memory")
def test_adamw_offload(rank, world_size):
    """Check AdamW ``moment1``/``moment2`` offloading alongside Muon params.

    A mixed optimizer (2-D Muon group plus 1-D AdamW group) is run with and
    without offload on identical data. Parameters must match exactly after
    each step, and the AdamW moments of the offloading optimizer must hold
    no GPU storage at the end.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    num_steps = 3
    n_muon, n_adamw = 4, 3

    def shard(tensor):
        # Shard ``tensor`` along dim 0 over the mesh.
        return distribute_tensor(tensor, mesh, [Shard(0)])

    # Template parameters: 2-D matrices for Muon, 1-D vectors for AdamW.
    muon_params, muon_names = [], []
    for idx in range(n_muon):
        matrix = torch.randn(64, 128, device="cuda")
        muon_params.append(torch.nn.Parameter(shard(matrix)))
        muon_names.append(f"layer.{idx}.weight")

    adamw_params, adamw_names = [], []
    for idx in range(n_adamw):
        vector = torch.randn(128, device="cuda")
        adamw_params.append(torch.nn.Parameter(shard(vector)))
        adamw_names.append(f"layer.{idx}.bias")

    # Gradients for every step are drawn up front.
    muon_grads = []
    for _ in range(num_steps):
        muon_grads.append(
            [torch.randn(64, 128, device="cuda") for _ in range(n_muon)]
        )
    adamw_grads = []
    for _ in range(num_steps):
        adamw_grads.append(
            [torch.randn(128, device="cuda") for _ in range(n_adamw)]
        )

    def clone_as_param(template):
        # Re-shard a copy of the template's gathered value as a new Parameter.
        return torch.nn.Parameter(shard(template.data.full_tensor().clone()))

    def make_optimizer(cpu_offload):
        mp = [clone_as_param(p) for p in muon_params]
        ap = [clone_as_param(p) for p in adamw_params]
        muon_group = {
            "params": mp,
            "names": list(muon_names),
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        adamw_group = {
            "params": ap,
            "use_muon": False,
            "lr": 1e-3,
            "weight_decay": 0.01,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        optim = Muon(
            params=[muon_group, adamw_group], chunk_size=2, warmup_step=1
        )
        if cpu_offload:
            optim.turn_on_cpu_offload()
        return optim, mp, ap

    optim_ref, mp_ref, ap_ref = make_optimizer(False)
    optim_off, mp_off, ap_off = make_optimizer(True)

    for step_idx in range(num_steps):
        for p_ref, p_off, grad in zip(mp_ref, mp_off, muon_grads[step_idx]):
            p_ref.grad = shard(grad.clone())
            p_off.grad = shard(grad.clone())
        for p_ref, p_off, grad in zip(ap_ref, ap_off, adamw_grads[step_idx]):
            p_ref.grad = shard(grad.clone())
            p_off.grad = shard(grad.clone())

        optim_ref.step()
        optim_off.step()

        # Muon parameters first, then AdamW parameters; all must match exactly.
        for p_ref, p_off in zip(mp_ref + ap_ref, mp_off + ap_off):
            expected = p_ref.data.full_tensor()
            actual = p_off.data.full_tensor()
            torch.testing.assert_close(expected, actual, atol=0, rtol=0)

        if rank == 0:
            logger.info("Step %d: AdamW offload correctness OK", step_idx)

    # The AdamW moments of the offloading optimizer must own no GPU storage.
    for param in ap_off:
        state = optim_off.state[param]
        for key in ("moment1", "moment2"):
            if key in state:
                moment = state[key]
                local_moment = (
                    moment._local_tensor if isinstance(moment, DTensor) else moment
                )
                assert local_moment.untyped_storage().size() == 0, (
                    f"AdamW {key} storage not freed after offload"
                )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_adamw_offload")
def test_memory_savings(rank, world_size):
    """Compare allocated GPU memory after one step with and without offload.

    The same single-step workload is run twice; the run with CPU offload
    enabled must report strictly less ``torch.cuda.memory_allocated()``.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    mesh = _make_mesh(world_size)

    rows, cols = 1024, 2048
    n_params = 8
    mib = 1024**2

    def run_step(cpu_offload):
        # Build params, take one optimizer step, and return allocated bytes.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.manual_seed(42)

        weights, labels = [], []
        for idx in range(n_params):
            full_weight = torch.randn(rows, cols, device="cuda")
            sharded_weight = distribute_tensor(full_weight, mesh, [Shard(0)])
            weight = torch.nn.Parameter(sharded_weight)
            weight.grad = distribute_tensor(
                torch.randn(rows, cols, device="cuda"), mesh, [Shard(0)]
            )
            weights.append(weight)
            labels.append(f"layer.{idx}.weight")

        groups = [
            {
                "params": weights,
                "names": labels,
                "use_muon": True,
                "lr": 0.02,
                "weight_decay": 0.01,
                "momentum": 0.95,
                "nesterov": True,
                "ns_steps": 5,
                "none_grad": False,
            }
        ]
        optimizer = Muon(params=groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            optimizer.turn_on_cpu_offload()

        optimizer.step()
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated()

        # Drop this run's objects so they do not skew the next measurement.
        del optimizer, weights, groups
        torch.cuda.empty_cache()
        return allocated

    mem_no_offload = run_step(False)
    mem_with_offload = run_step(True)

    if rank == 0:
        logger.info("Memory without offload: %.2f MB", mem_no_offload / mib)
        logger.info("Memory with offload: %.2f MB", mem_with_offload / mib)
        saved = mem_no_offload - mem_with_offload
        logger.info("Memory saved: %.2f MB", saved / mib)

    assert mem_with_offload < mem_no_offload, (
        f"Expected memory reduction with CPU offload. "
        f"Without: {mem_no_offload / 1024**2:.2f} MB, "
        f"With: {mem_with_offload / 1024**2:.2f} MB"
    )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_memory_savings")
def test_toggle_correctness(rank, world_size):
    """Check that flipping offload on and off mid-run leaves results unchanged.

    A reference optimizer never offloads. A second optimizer has offload
    switched every two steps (on, on, off, off, on, on). Parameters must be
    exactly equal after each step.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    rows, cols = 64, 128
    n_params = 4
    n_steps = 6

    def shard_copy(tensor):
        # Shard a clone of ``tensor`` along dim 0 over the mesh.
        return distribute_tensor(tensor.clone(), mesh, [Shard(0)])

    init_weights = []
    for _ in range(n_params):
        init_weights.append(torch.randn(rows, cols, device="cuda"))
    grads_by_step = []
    for _ in range(n_steps):
        grads_by_step.append(
            [torch.randn(rows, cols, device="cuda") for _ in range(n_params)]
        )

    def make_optimizer():
        # Fresh sharded parameters plus a Muon optimizer; offload starts off.
        weights = [torch.nn.Parameter(shard_copy(w)) for w in init_weights]
        labels = [f"layer.{idx}.weight" for idx in range(len(weights))]
        group = {
            "params": weights,
            "names": labels,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
        return Muon(params=[group], chunk_size=2, warmup_step=1), weights

    # Baseline that never offloads.
    optim_ref, params_ref = make_optimizer()
    # Optimizer under test, whose offload setting is flipped during the run.
    optim_toggle, params_toggle = make_optimizer()

    for step_idx, step_grads in enumerate(grads_by_step):
        # Offload is wanted for steps 0-1 and 4-5, and unwanted for steps 2-3.
        offload_wanted = step_idx % 4 < 2
        if offload_wanted != bool(optim_toggle.cpu_offload):
            if offload_wanted:
                optim_toggle.turn_on_cpu_offload()
            else:
                optim_toggle.turn_off_cpu_offload()

        for p_ref, p_tog, grad in zip(params_ref, params_toggle, step_grads):
            p_ref.grad = shard_copy(grad)
            p_tog.grad = shard_copy(grad)

        optim_ref.step()
        optim_toggle.step()

        for p_ref, p_tog in zip(params_ref, params_toggle):
            expected = p_ref.data.full_tensor()
            actual = p_tog.data.full_tensor()
            torch.testing.assert_close(expected, actual, atol=0, rtol=0)

        if rank == 0:
            logger.info(
                "Step %d (offload=%s): toggle correctness OK",
                step_idx,
                optim_toggle.cpu_offload,
            )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_toggle_correctness")
def test_leak(rank, world_size):
    """Run many offloaded optimizer steps and check memory does not grow.

    A baseline of allocated GPU memory and process RSS is recorded after a
    short warmup. After all steps, GPU memory must not exceed the baseline
    and RSS must have grown by less than a fixed budget.

    RSS is read from ``/proc/self/statm``, so this test needs a Linux-style
    procfs.
    """
    import os

    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    dim0, dim1 = 512, 1024
    num_params = 8
    num_steps = 50
    # Zero-based index of the step after which the baseline is recorded.
    warmup_step_idx = 2
    # Allowed RSS growth in MB after the baseline.
    cpu_growth_limit_mb = 50

    # Without this the baselines would stay None and fail with a TypeError.
    assert num_steps > warmup_step_idx, (
        "num_steps must exceed warmup_step_idx so a baseline is recorded"
    )

    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        params.append(p)
        names.append(f"layer.{i}.weight")

    param_groups = [
        {
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
    ]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()

    def get_cpu_rss_mb():
        """Get current process RSS in MB from /proc/self/statm."""
        with open("/proc/self/statm") as f:
            # The second field of statm is the resident set size in pages.
            pages = int(f.read().split()[1])
        return pages * os.sysconf("SC_PAGE_SIZE") / (1024**2)

    gpu_after_warmup = None
    cpu_after_warmup = None

    for step_idx in range(num_steps):
        for p in params:
            p.grad = distribute_tensor(
                torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
            )
        optim.step()
        torch.cuda.synchronize()

        gpu_mem = torch.cuda.memory_allocated()
        cpu_mem = get_cpu_rss_mb()

        # Record the baseline once the warmup steps have completed, so that
        # one-off allocations from the first steps are not counted as leaks.
        if step_idx == warmup_step_idx:
            gpu_after_warmup = gpu_mem
            cpu_after_warmup = cpu_mem

        if rank == 0 and step_idx % 10 == 0:
            logger.info(
                "Step %d: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
                step_idx,
                gpu_mem / (1024**2),
                cpu_mem,
            )

    # Final measurements.
    torch.cuda.synchronize()
    gpu_final = torch.cuda.memory_allocated()
    cpu_final = get_cpu_rss_mb()

    if rank == 0:
        logger.info(
            "After %d steps: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
            num_steps,
            gpu_final / (1024**2),
            cpu_final,
        )
        logger.info(
            "Warmup baseline: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
            gpu_after_warmup / (1024**2),
            cpu_after_warmup,
        )

    # GPU memory should not grow beyond the warmup baseline.
    assert gpu_final <= gpu_after_warmup, (
        f"GPU memory leak detected! Warmup: {gpu_after_warmup / 1024**2:.2f} MB, "
        f"Final: {gpu_final / 1024**2:.2f} MB"
    )

    # CPU RSS growth is capped: small runtime overhead is tolerated, real
    # leaks are not.
    cpu_growth = cpu_final - cpu_after_warmup
    # Number of steps executed after the baseline was taken.
    steps_after_warmup = num_steps - 1 - warmup_step_idx
    assert cpu_growth < cpu_growth_limit_mb, (
        f"CPU memory leak detected! Growth: {cpu_growth:.2f} MB over "
        f"{steps_after_warmup} steps (warmup={cpu_after_warmup:.2f} MB, "
        f"final={cpu_final:.2f} MB)"
    )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)", cpu_growth)
def test_state_dict_save_load(rank, world_size):
    """Round-trip optimizer state through a checkpoint and resume stepping.

    State tensors are serialized with torch.distributed.checkpoint (DCP).
    One fresh optimizer loads an in-memory deep copy of the state dict and a
    second loads the tensors read back through DCP. Both then take the same
    further steps and their parameters must match exactly. The test also
    checks that ``state_dict()`` and ``load_state_dict()`` raise while CPU
    offload is still enabled.
    """
    import shutil
    import tempfile

    import torch.distributed.checkpoint as dcp

    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    rows, cols = 64, 128
    n_muon = 4
    n_adamw = 3
    half_steps = 3
    total_steps = half_steps * 2

    # Regexes for the errors raised when checkpointing with offload enabled.
    save_guard = "turn_off_cpu_offload\\(\\) before checkpoint save"
    load_guard = "turn_off_cpu_offload\\(\\) before checkpoint load"

    def shard_copy(tensor):
        # Shard a clone of ``tensor`` along dim 0 over the mesh.
        return distribute_tensor(tensor.clone(), mesh, [Shard(0)])

    # Draw all data up front: initial weights first, then per-step gradients.
    muon_init = [torch.randn(rows, cols, device="cuda") for _ in range(n_muon)]
    adamw_init = [torch.randn(cols, device="cuda") for _ in range(n_adamw)]
    muon_grads = []
    for _ in range(total_steps):
        muon_grads.append(
            [torch.randn(rows, cols, device="cuda") for _ in range(n_muon)]
        )
    adamw_grads = []
    for _ in range(total_steps):
        adamw_grads.append(
            [torch.randn(cols, device="cuda") for _ in range(n_adamw)]
        )

    def make_optimizer(cpu_offload):
        mp = [torch.nn.Parameter(shard_copy(w)) for w in muon_init]
        ap = [torch.nn.Parameter(shard_copy(w)) for w in adamw_init]
        muon_group = {
            "params": mp,
            "names": [f"layer.{i}.weight" for i in range(n_muon)],
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        adamw_group = {
            "params": ap,
            "use_muon": False,
            "lr": 1e-3,
            "weight_decay": 0.01,
            "adamw_betas": (0.9, 0.95),
            "adamw_eps": 1e-8,
        }
        optim = Muon(
            params=[muon_group, adamw_group], chunk_size=2, warmup_step=1
        )
        if cpu_offload:
            optim.turn_on_cpu_offload()
        return optim, mp, ap

    # First half: a single offloading optimizer produces the state to save.
    optim_off, mp_off, ap_off = make_optimizer(True)

    def fresh_optimizer_at_saved_weights():
        # New offloading optimizer whose weights are cloned from the saved run.
        optim, mp, ap = make_optimizer(True)
        for dst, src in zip(mp, mp_off):
            dst.data = src.data.clone()
        for dst, src in zip(ap, ap_off):
            dst.data = src.data.clone()
        return optim, mp, ap

    for step_idx in range(half_steps):
        for param, grad in zip(mp_off, muon_grads[step_idx]):
            param.grad = shard_copy(grad)
        for param, grad in zip(ap_off, adamw_grads[step_idx]):
            param.grad = shard_copy(grad)
        optim_off.step()

    # Saving while offload is on must be rejected.
    with pytest.raises(RuntimeError, match=save_guard):
        optim_off.state_dict()

    optim_off.turn_off_cpu_offload()
    sd_off = optim_off.state_dict()

    # Every floating-point state tensor must have real storage behind it.
    for param_states in sd_off["state"].values():
        for key, val in param_states.items():
            if not isinstance(val, torch.Tensor) or not val.is_floating_point():
                continue
            assert val.untyped_storage().size() > 0, (
                f"state_dict() returned empty storage for key '{key}' — "
                f"offload reload is broken"
            )

    if rank == 0:
        logger.info("state_dict() contains valid (non-empty) tensors")

    # DCP is given a flat mapping of string keys to state tensors.
    flat_state = {
        f"state.{param_idx}.{key}": val
        for param_idx, param_state in sd_off["state"].items()
        for key, val in param_state.items()
        if isinstance(val, torch.Tensor)
    }

    # Rank 0 creates the directory; the path is then shared with all ranks.
    ckpt_dir = tempfile.mkdtemp(prefix="cpu_offload_test_") if rank == 0 else ""
    path_holder = [ckpt_dir]
    dist.broadcast_object_list(path_holder, src=0)
    ckpt_dir = path_holder[0]

    try:
        dcp.save(flat_state, checkpoint_id=ckpt_dir)
        dist.barrier()

        if rank == 0:
            logger.info("DCP save completed to %s", ckpt_dir)

        # Reference optimizer: state arrives via deepcopy, with no serialization.
        optim_ref, mp_ref, ap_ref = fresh_optimizer_at_saved_weights()

        with pytest.raises(RuntimeError, match=load_guard):
            optim_ref.load_state_dict(copy.deepcopy(sd_off))

        optim_ref.turn_off_cpu_offload()
        optim_ref.load_state_dict(copy.deepcopy(sd_off))
        optim_ref.turn_on_cpu_offload()

        # Resumed optimizer: state tensors come back through DCP.
        optim_resumed, mp_resumed, ap_resumed = fresh_optimizer_at_saved_weights()

        flat_target = {}
        for name, saved in flat_state.items():
            flat_target[name] = torch.zeros_like(saved)
        dcp.load(flat_target, checkpoint_id=ckpt_dir)
        dist.barrier()

        # Swap the DCP-loaded tensors into a copy of the original state dict.
        sd_loaded = copy.deepcopy(sd_off)
        for param_idx, param_state in sd_loaded["state"].items():
            for key in list(param_state):
                flat_key = f"state.{param_idx}.{key}"
                if flat_key in flat_target:
                    param_state[key] = flat_target[flat_key]

        with pytest.raises(RuntimeError, match=load_guard):
            optim_resumed.load_state_dict(copy.deepcopy(sd_loaded))

        optim_resumed.turn_off_cpu_offload()
        optim_resumed.load_state_dict(sd_loaded)
        optim_resumed.turn_on_cpu_offload()

        if rank == 0:
            logger.info("Both optimizers loaded, starting comparison steps")
    finally:
        # Wait for every rank before rank 0 removes the checkpoint directory.
        dist.barrier()
        if rank == 0:
            shutil.rmtree(ckpt_dir, ignore_errors=True)

    # Second half: both optimizers continue on the remaining gradients.
    for step_idx in range(half_steps, total_steps):
        for p_ref, p_res, grad in zip(mp_ref, mp_resumed, muon_grads[step_idx]):
            p_ref.grad = shard_copy(grad)
            p_res.grad = shard_copy(grad)
        for p_ref, p_res, grad in zip(ap_ref, ap_resumed, adamw_grads[step_idx]):
            p_ref.grad = shard_copy(grad)
            p_res.grad = shard_copy(grad)
        optim_ref.step()
        optim_resumed.step()

    # Final parameters must be exactly equal, Muon first and then AdamW.
    for p_ref, p_res in zip(mp_ref + ap_ref, mp_resumed + ap_resumed):
        expected = p_ref.data.full_tensor()
        actual = p_res.data.full_tensor()
        torch.testing.assert_close(expected, actual, atol=0, rtol=0)

    # The resumed optimizer must be offloading again after stepping.
    for param in mp_resumed:
        state = optim_resumed.state[param]
        if "momentum_buffer" not in state:
            continue
        buf = state["momentum_buffer"]
        local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf
        assert local_buf.untyped_storage().size() == 0, (
            "Resumed optimizer should have offloaded state after step()"
        )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_state_dict_save_load")
def test_checkpoint_memory(rank, world_size):
    """Verify checkpoint APIs require offload to be disabled explicitly.

    Also tracks ``torch.cuda.memory_allocated()`` across the cycle
    offload -> turn off -> turn on -> step -> load -> turn on -> step:
    turning offload off must raise allocated memory, and every offloaded
    state must return to the level seen after the first offloaded step,
    within a fixed tolerance.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)

    dim0, dim1 = 512, 1024
    num_params = 8
    # Slack allowed on memory comparisons for CUDA allocator fragmentation.
    tolerance_bytes = 4 * 1024 * 1024

    params, names = [], []
    for i in range(num_params):
        full = torch.randn(dim0, dim1, device="cuda")
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
        )
        params.append(p)
        names.append(f"layer.{i}.weight")

    param_groups = [
        {
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
    ]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()

    # Run a first step so offload initializes; this is the memory baseline.
    optim.step()
    torch.cuda.synchronize()
    mem_after_step = torch.cuda.memory_allocated()

    # Sum the momentum-buffer sizes for the log line below. The GPU storage
    # is already freed, so sizes come from the pool's own bookkeeping.
    state_bytes = 0
    for p in params:
        state = optim.state[p]
        if "momentum_buffer" in state:
            buf = state["momentum_buffer"]
            state_bytes += optim._cpu_offload_pool._storage_nbytes[id(buf)]

    if rank == 0:
        logger.info(
            "After step (offloaded): GPU alloc=%.2f MB, expected state size=%.2f MB",
            mem_after_step / 1024**2,
            state_bytes / 1024**2,
        )

    # Saving while offload is enabled must be rejected.
    with pytest.raises(
        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save"
    ):
        optim.state_dict()

    optim.turn_off_cpu_offload()
    torch.cuda.synchronize()
    mem_after_turn_off = torch.cuda.memory_allocated()
    sd_for_load = copy.deepcopy(optim.state_dict())

    if rank == 0:
        logger.info(
            "After turn_off_cpu_offload: GPU alloc=%.2f MB",
            mem_after_turn_off / 1024**2,
        )

    assert mem_after_turn_off > mem_after_step, (
        f"turn_off_cpu_offload() should reload states to GPU. "
        f"After offload: {mem_after_step / 1024**2:.2f} MB, "
        f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB"
    )

    # Re-enabling offload must bring memory back to the baseline.
    optim.turn_on_cpu_offload()
    torch.cuda.synchronize()
    mem_after_turn_on = torch.cuda.memory_allocated()

    if rank == 0:
        logger.info(
            "After turn_on_cpu_offload: GPU alloc=%.2f MB", mem_after_turn_on / 1024**2
        )

    assert mem_after_turn_on <= mem_after_step + tolerance_bytes, (
        f"turn_on_cpu_offload() should return memory to offloaded level. "
        f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
        f"got {mem_after_turn_on / 1024**2:.2f} MB"
    )

    # A further step must leave memory at the baseline as well.
    for p in params:
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
        )
    optim.step()
    torch.cuda.synchronize()
    mem_after_next_step = torch.cuda.memory_allocated()

    if rank == 0:
        logger.info(
            "After next step (re-offloaded): GPU alloc=%.2f MB",
            mem_after_next_step / 1024**2,
        )

    assert mem_after_next_step <= mem_after_step + tolerance_bytes, (
        f"Memory should return to offloaded level after step(). "
        f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
        f"got {mem_after_next_step / 1024**2:.2f} MB"
    )

    # Loading while offload is enabled must be rejected.
    with pytest.raises(
        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load"
    ):
        optim.load_state_dict(copy.deepcopy(sd_for_load))

    optim.turn_off_cpu_offload()
    optim.load_state_dict(sd_for_load)
    torch.cuda.synchronize()
    mem_after_load = torch.cuda.memory_allocated()

    if rank == 0:
        logger.info(
            "After load_state_dict with offload disabled: GPU alloc=%.2f MB",
            mem_after_load / 1024**2,
        )

    assert mem_after_load >= mem_after_turn_off, (
        "Loaded optimizer state should stay on GPU while offload is disabled"
    )

    optim.turn_on_cpu_offload()
    torch.cuda.synchronize()

    pool = optim._cpu_offload_pool
    assert pool._initialized, (
        "Offload pool should be initialized after re-enabling offload"
    )
    for grp in pool._groups.values():
        assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned"

    # Finally, the loaded optimizer must still step and re-offload.
    for p in params:
        p.grad = distribute_tensor(
            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
        )
    optim.step()
    torch.cuda.synchronize()
    mem_final = torch.cuda.memory_allocated()

    assert mem_final <= mem_after_step + tolerance_bytes, (
        f"Final memory should be at offloaded level. "
        f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
        f"got {mem_final / 1024**2:.2f} MB"
    )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_checkpoint_memory")
def main():
    """Set up the process group, run every test in order, then tear down.

    The process group is destroyed even when a test raises.
    """
    rank, world_size = _setup()
    try:
        # Tests run sequentially; the first failure aborts the remainder.
        test_sequence = (
            test_correctness,
            test_memory,
            test_adamw_offload,
            test_memory_savings,
            test_toggle_correctness,
            test_leak,
            test_state_dict_save_load,
            test_checkpoint_memory,
        )
        for run_test in test_sequence:
            run_test(rank, world_size)

        if rank == 0:
            banner = "=" * 50
            logger.info(banner)
            logger.info("ALL CPU OFFLOAD TESTS PASSED")
            logger.info(banner)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()