Kernels
Commit d7e2d58 · unverified · 1 Parent(s): e8e2c81
wyldecat committed

feat: add sync_streams parameter to CPUOffloadPool.reload_group (#31)


* feat: tag-based per-group reload for CPU offload + manual offload mode

Extend CPUOffloadPool + Muon so callers can drive offload/reload
explicitly (layer-lockstep overlap) rather than implicitly inside
optimizer.step.

CPUOffloadPool (torch-ext/optimizer/cpu_offload.py):
* ``track(tensor, tag=None)`` — optional tag per managed tensor.
* ``reload_group(tag, sync_streams=())`` — reload just the tensors
  tagged with ``tag``; the reload stream calls ``wait_stream`` on each
  entry in ``sync_streams`` before issuing H2D copies. This avoids
  allocator cross-stream reuse races under
  ``PYTORCH_ALLOC_CONF=expandable_segments:True``: if the block
  returned by ``storage.resize_`` was last touched on an FSDP
  all-gather stream, waiting on that stream enforces FIFO ordering
  between the prior use and our H2D write. (Usage of the tag-based
  pool API is sketched after this list.)
* ``reload_untagged()`` — bulk-reload everything not attached to a
tag (for the non-expert portion of the optimizer state in layer-
lockstep flows).
* ``wait_reload()`` is now self-clearing (resets ``_reload_event``
after one wait).
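
A minimal sketch of how the pool-level pieces fit together. The buffer
names, the hook timing, and ``fsdp_all_gather_stream`` are illustrative
assumptions; only the ``CPUOffloadPool`` methods themselves come from
this commit:

    # Illustrative only: the tracked tensors and streams are stand-ins.
    from optimizer.cpu_offload import CPUOffloadPool  # torch-ext/optimizer/cpu_offload.py

    pool = CPUOffloadPool()

    # Register each expert layer's state tensor under a per-layer tag;
    # anything tracked without a tag is handled by reload_untagged().
    for i, buf in enumerate(expert_momentum_buffers):  # assumed GPU tensors
        pool.track(buf, tag=f"layer.{i}")
    pool.track(shared_momentum_buffer)  # untagged

    pool.offload()  # D2H: copy into pinned CPU buffers, free GPU storages

    # Later, as layer 3's backward completes (e.g. from a module hook):
    pool.reload_group("layer.3", sync_streams=(fsdp_all_gather_stream,))

    # Before optimizer.step(): bulk-reload the rest, then fence the
    # default stream on the async H2D copies.
    pool.reload_untagged()
    pool.wait_reload()  # self-clearing: resets the reload event after one wait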

Muon (torch-ext/optimizer/muon.py):
* ``manual_offload`` flag: when set, ``step()`` skips its own
  ``reload`` / ``offload`` calls so the caller can drive them
  (see the driver sketch below).
* ``set_param_tags(id(param) -> tag)``: propagated to the pool on
  first offload so ``reload_group`` picks up the right tensors.
* New public helpers: ``reload_group``, ``reload_untagged``,
  ``wait_reload``, ``offload`` — mirroring the pool semantics.
* Baseline (non-manual) path: add explicit ``wait_reload`` after
``pool.reload()`` in ``step`` and ``turn_off_cpu_offload``.

Default ``sync_streams=()`` keeps existing callers behaviourally
unchanged.
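
A hedged sketch of the layer-lockstep driver this enables on the Muon
side. The model structure, hook wiring, data loader, and
``fsdp_ag_stream`` are assumptions for illustration; only the Muon
methods named above come from this commit:

    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()
    optim.manual_offload = True  # step() now skips its own reload/offload

    # Tag expert params per layer before the first step, so the pool
    # records the tags when states are first registered on offload.
    optim.set_param_tags({
        id(p): f"layer.{i}"
        for i, layer in enumerate(model.layers)
        for p in layer.experts.parameters()
    })

    for batch in loader:
        loss = model(batch).sum()
        # Backward hooks (assumed) overlap per-layer H2D with backprop:
        #   optim.reload_group(f"layer.{i}", sync_streams=(fsdp_ag_stream,))
        loss.backward()
        optim.reload_untagged()  # non-expert states in one bulk call
        optim.wait_reload()      # default stream waits for the async H2D
        optim.step()             # manual mode: no implicit reload/offload
        optim.offload()          # explicit D2H after the step

Note the ordering constraint from the docstrings: ``set_param_tags``
must run before the first ``step()``/``offload()``, since the pool only
sees tags when the states are first registered.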

* style: yapf

* chore: whitelist math notation 'Ot' for typos hook

_typos.toml ADDED
@@ -0,0 +1,3 @@
+[default.extend-words]
+# Math notation used in docs/muon-clip.md (O subscript t, update step output)
+Ot = "Ot"
test/test_cpu_offload.py CHANGED
@@ -29,7 +29,8 @@ def _setup():
 
 
 def _make_mesh(world_size):
-    return dist.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
+    return dist.init_device_mesh("cuda", (world_size, ),
+                                 mesh_dim_names=("dp", ))
 
 
 def test_correctness(rank, world_size):
@@ -47,11 +48,12 @@ def test_correctness(rank, world_size):
     num_steps = 3
 
     # Pre-generate all data on all ranks (same seed → same values).
-    full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)]
-    full_grads = [
-        [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)]
-        for _ in range(num_steps)
+    full_params = [
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)
     ]
+    full_grads = [[
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)
+    ] for _ in range(num_steps)]
 
     def make_optimizer(cpu_offload):
         params, names = [], []
@@ -60,19 +62,17 @@ def test_correctness(rank, world_size):
             p = torch.nn.Parameter(dt)
            params.append(p)
            names.append(f"layer.{i}.weight")
-        param_groups = [
-            {
-                "params": params,
-                "names": names,
-                "use_muon": True,
-                "lr": 0.02,
-                "weight_decay": 0.01,
-                "momentum": 0.95,
-                "nesterov": True,
-                "ns_steps": 5,
-                "none_grad": False,
-            }
-        ]
+        param_groups = [{
+            "params": params,
+            "names": names,
+            "use_muon": True,
+            "lr": 0.02,
+            "weight_decay": 0.01,
+            "momentum": 0.95,
+            "nesterov": True,
+            "ns_steps": 5,
+            "none_grad": False,
+        }]
         optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
         if cpu_offload:
             optim.turn_on_cpu_offload()
@@ -121,25 +121,22 @@ def test_memory(rank, world_size):
         full = torch.randn(dim0, dim1, device="cuda")
         dt = distribute_tensor(full, mesh, [Shard(0)])
         p = torch.nn.Parameter(dt)
-        p.grad = distribute_tensor(
-            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-        )
+        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                   mesh, [Shard(0)])
         params.append(p)
         names.append(f"layer.{i}.weight")
 
-    param_groups = [
-        {
-            "params": params,
-            "names": names,
-            "use_muon": True,
-            "lr": 0.02,
-            "weight_decay": 0.01,
-            "momentum": 0.95,
-            "nesterov": True,
-            "ns_steps": 5,
-            "none_grad": False,
-        }
-    ]
+    param_groups = [{
+        "params": params,
+        "names": names,
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
     optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
     optim.turn_on_cpu_offload()
 
@@ -155,8 +152,7 @@ def test_memory(rank, world_size):
         local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf
         assert local_buf.untyped_storage().size() == 0, (
             f"Expected freed GPU storage after offload, got "
-            f"{local_buf.untyped_storage().size()} bytes"
-        )
+            f"{local_buf.untyped_storage().size()} bytes")
 
     # Verify CPU pool has pinned buffers.
     pool = optim._cpu_offload_pool
@@ -166,9 +162,8 @@ def test_memory(rank, world_size):
 
     # Run another step to verify reload + compute + offload cycle works.
     for p in params:
-        p.grad = distribute_tensor(
-            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-        )
+        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                   mesh, [Shard(0)])
     optim.step()
     torch.cuda.synchronize()
 
@@ -217,26 +212,21 @@ def test_adamw_offload(rank, world_size):
         adamw_names.append(f"layer.{i}.bias")
 
     # Pre-generate grads.
-    muon_grads = [
-        [torch.randn(64, 128, device="cuda") for _ in range(4)]
-        for _ in range(num_steps)
-    ]
-    adamw_grads = [
-        [torch.randn(128, device="cuda") for _ in range(3)] for _ in range(num_steps)
-    ]
+    muon_grads = [[torch.randn(64, 128, device="cuda") for _ in range(4)]
+                  for _ in range(num_steps)]
+    adamw_grads = [[torch.randn(128, device="cuda") for _ in range(3)]
+                   for _ in range(num_steps)]
 
     def make_optimizer(cpu_offload):
         mp = [
             torch.nn.Parameter(
-                distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)])
-            )
-            for p in muon_params
+                distribute_tensor(p.data.full_tensor().clone(), mesh,
+                                  [Shard(0)])) for p in muon_params
         ]
         ap = [
             torch.nn.Parameter(
-                distribute_tensor(p.data.full_tensor().clone(), mesh, [Shard(0)])
-            )
-            for p in adamw_params
+                distribute_tensor(p.data.full_tensor().clone(), mesh,
+                                  [Shard(0)])) for p in adamw_params
         ]
         param_groups = [
             {
@@ -306,8 +296,7 @@ def test_adamw_offload(rank, world_size):
             t = state[key]
             local_t = t._local_tensor if isinstance(t, DTensor) else t
             assert local_t.untyped_storage().size() == 0, (
-                f"AdamW {key} storage not freed after offload"
-            )
+                f"AdamW {key} storage not freed after offload")
 
     set_ns_compile(True)
     if rank == 0:
@@ -335,25 +324,22 @@ def test_memory_savings(rank, world_size):
             full = torch.randn(dim0, dim1, device="cuda")
             dt = distribute_tensor(full, mesh, [Shard(0)])
             p = torch.nn.Parameter(dt)
-            p.grad = distribute_tensor(
-                torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-            )
+            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                       mesh, [Shard(0)])
            params.append(p)
            names.append(f"layer.{i}.weight")
 
-        param_groups = [
-            {
-                "params": params,
-                "names": names,
-                "use_muon": True,
-                "lr": 0.02,
-                "weight_decay": 0.01,
-                "momentum": 0.95,
-                "nesterov": True,
-                "ns_steps": 5,
-                "none_grad": False,
-            }
-        ]
+        param_groups = [{
+            "params": params,
+            "names": names,
+            "use_muon": True,
+            "lr": 0.02,
+            "weight_decay": 0.01,
+            "momentum": 0.95,
+            "nesterov": True,
+            "ns_steps": 5,
+            "none_grad": False,
+        }]
         optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
         if cpu_offload:
             optim.turn_on_cpu_offload()
@@ -370,16 +356,17 @@ def test_memory_savings(rank, world_size):
     mem_with_offload = run_step(True)
 
     if rank == 0:
-        logger.info("Memory without offload: %.2f MB", mem_no_offload / 1024**2)
-        logger.info("Memory with offload: %.2f MB", mem_with_offload / 1024**2)
+        logger.info("Memory without offload: %.2f MB",
+                    mem_no_offload / 1024**2)
+        logger.info("Memory with offload: %.2f MB",
+                    mem_with_offload / 1024**2)
         saved = mem_no_offload - mem_with_offload
         logger.info("Memory saved: %.2f MB", saved / 1024**2)
 
     assert mem_with_offload < mem_no_offload, (
         f"Expected memory reduction with CPU offload. "
         f"Without: {mem_no_offload / 1024**2:.2f} MB, "
-        f"With: {mem_with_offload / 1024**2:.2f} MB"
-    )
+        f"With: {mem_with_offload / 1024**2:.2f} MB")
 
     set_ns_compile(True)
     if rank == 0:
@@ -400,11 +387,12 @@ def test_toggle_correctness(rank, world_size):
     num_params = 4
     num_steps = 6
 
-    full_params = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)]
-    full_grads = [
-        [torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)]
-        for _ in range(num_steps)
+    full_params = [
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)
     ]
+    full_grads = [[
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_params)
+    ] for _ in range(num_steps)]
 
     def make_optimizer():
         params, names = [], []
@@ -413,19 +401,17 @@ def test_toggle_correctness(rank, world_size):
             p = torch.nn.Parameter(dt)
            params.append(p)
            names.append(f"layer.{i}.weight")
-        param_groups = [
-            {
-                "params": params,
-                "names": names,
-                "use_muon": True,
-                "lr": 0.02,
-                "weight_decay": 0.01,
-                "momentum": 0.95,
-                "nesterov": True,
-                "ns_steps": 5,
-                "none_grad": False,
-            }
-        ]
+        param_groups = [{
+            "params": params,
+            "names": names,
+            "use_muon": True,
+            "lr": 0.02,
+            "weight_decay": 0.01,
+            "momentum": 0.95,
+            "nesterov": True,
+            "ns_steps": 5,
+            "none_grad": False,
+        }]
         optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
         return optim, params
 
@@ -446,7 +432,8 @@ def test_toggle_correctness(rank, world_size):
         for i in range(num_params):
             g = full_grads[step_idx][i]
             params_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
-            params_toggle[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
+            params_toggle[i].grad = distribute_tensor(g.clone(), mesh,
+                                                      [Shard(0)])
 
         optim_ref.step()
         optim_toggle.step()
@@ -492,19 +479,17 @@ def test_leak(rank, world_size):
         params.append(p)
         names.append(f"layer.{i}.weight")
 
-    param_groups = [
-        {
-            "params": params,
-            "names": names,
-            "use_muon": True,
-            "lr": 0.02,
-            "weight_decay": 0.01,
-            "momentum": 0.95,
-            "nesterov": True,
-            "ns_steps": 5,
-            "none_grad": False,
-        }
-    ]
+    param_groups = [{
+        "params": params,
+        "names": names,
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
     optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
     optim.turn_on_cpu_offload()
 
@@ -519,9 +504,8 @@ def test_leak(rank, world_size):
 
     for step_idx in range(num_steps):
        for p in params:
-            p.grad = distribute_tensor(
-                torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-            )
+            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                       mesh, [Shard(0)])
 
         optim.step()
         torch.cuda.synchronize()
@@ -564,8 +548,7 @@ def test_leak(rank, world_size):
     # GPU memory should not grow beyond warmup baseline.
     assert gpu_final <= gpu_after_warmup, (
         f"GPU memory leak detected! Warmup: {gpu_after_warmup / 1024**2:.2f} MB, "
-        f"Final: {gpu_final / 1024**2:.2f} MB"
-    )
+        f"Final: {gpu_final / 1024**2:.2f} MB")
 
     # CPU RSS should not grow more than 50 MB over warmup (allows for minor
     # Python/CUDA runtime overhead but catches real leaks).
@@ -573,12 +556,12 @@ def test_leak(rank, world_size):
     assert cpu_growth < 50, (
         f"CPU memory leak detected! Growth: {cpu_growth:.2f} MB over "
         f"{num_steps - 2} steps (warmup={cpu_after_warmup:.2f} MB, "
-        f"final={cpu_final:.2f} MB)"
-    )
+        f"final={cpu_final:.2f} MB)")
 
     set_ns_compile(True)
     if rank == 0:
-        logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)", cpu_growth)
+        logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)",
+                    cpu_growth)
 
 
 def test_state_dict_save_load(rank, world_size):
@@ -606,28 +589,26 @@ def test_state_dict_save_load(rank, world_size):
     num_steps = 3
 
     # Pre-generate all data.
-    muon_init = [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)]
-    adamw_init = [torch.randn(dim1, device="cuda") for _ in range(num_adamw)]
-    all_grads_muon = [
-        [torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)]
-        for _ in range(num_steps * 2)
-    ]
-    all_grads_adamw = [
-        [torch.randn(dim1, device="cuda") for _ in range(num_adamw)]
-        for _ in range(num_steps * 2)
+    muon_init = [
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)
     ]
+    adamw_init = [torch.randn(dim1, device="cuda") for _ in range(num_adamw)]
+    all_grads_muon = [[
+        torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)
+    ] for _ in range(num_steps * 2)]
+    all_grads_adamw = [[
+        torch.randn(dim1, device="cuda") for _ in range(num_adamw)
+    ] for _ in range(num_steps * 2)]
 
     def make_optimizer(cpu_offload):
         mp = [
             torch.nn.Parameter(
-                distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)])
-            )
+                distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)]))
             for i in range(num_muon)
         ]
         ap = [
             torch.nn.Parameter(
-                distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)])
-            )
+                distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)]))
             for i in range(num_adamw)
         ]
         param_groups = [
@@ -666,17 +647,15 @@ def test_state_dict_save_load(rank, world_size):
     for step_idx in range(num_steps):
         for i in range(num_muon):
             mp_off[i].grad = distribute_tensor(
-                all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)]
-            )
+                all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)])
         for i in range(num_adamw):
             ap_off[i].grad = distribute_tensor(
-                all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)]
-            )
+                all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)])
         optim_off.step()
 
     with pytest.raises(
-        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save"
-    ):
+            RuntimeError,
+            match="turn_off_cpu_offload\\(\\) before checkpoint save"):
         optim_off.state_dict()
 
     optim_off.turn_off_cpu_offload()
@@ -688,8 +667,7 @@ def test_state_dict_save_load(rank, world_size):
         if isinstance(val, torch.Tensor) and val.is_floating_point():
             assert val.untyped_storage().size() > 0, (
                 f"state_dict() returned empty storage for key '{key}' — "
-                f"offload reload is broken"
-            )
+                f"offload reload is broken")
 
     if rank == 0:
         logger.info("state_dict() contains valid (non-empty) tensors")
@@ -724,8 +702,8 @@ def test_state_dict_save_load(rank, world_size):
     for i in range(num_adamw):
         ap_ref[i].data = ap_off[i].data.clone()
     with pytest.raises(
-        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load"
-    ):
+            RuntimeError,
+            match="turn_off_cpu_offload\\(\\) before checkpoint load"):
         optim_ref.load_state_dict(copy.deepcopy(sd_off))
     optim_ref.turn_off_cpu_offload()
     optim_ref.load_state_dict(copy.deepcopy(sd_off))
@@ -749,8 +727,8 @@ def test_state_dict_save_load(rank, world_size):
         if flat_key in flat_target:
             param_state[key] = flat_target[flat_key]
     with pytest.raises(
-        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load"
-    ):
+            RuntimeError,
+            match="turn_off_cpu_offload\\(\\) before checkpoint load"):
         optim_resumed.load_state_dict(copy.deepcopy(sd_loaded))
     optim_resumed.turn_off_cpu_offload()
     optim_resumed.load_state_dict(sd_loaded)
@@ -795,8 +773,7 @@ def test_state_dict_save_load(rank, world_size):
         buf = state["momentum_buffer"]
         local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf
         assert local_buf.untyped_storage().size() == 0, (
-            "Resumed optimizer should have offloaded state after step()"
-        )
+            "Resumed optimizer should have offloaded state after step()")
 
     set_ns_compile(True)
     if rank == 0:
@@ -821,25 +798,22 @@ def test_checkpoint_memory(rank, world_size):
         full = torch.randn(dim0, dim1, device="cuda")
         dt = distribute_tensor(full, mesh, [Shard(0)])
         p = torch.nn.Parameter(dt)
-        p.grad = distribute_tensor(
-            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-        )
+        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                   mesh, [Shard(0)])
         params.append(p)
         names.append(f"layer.{i}.weight")
 
-    param_groups = [
-        {
-            "params": params,
-            "names": names,
-            "use_muon": True,
-            "lr": 0.02,
-            "weight_decay": 0.01,
-            "momentum": 0.95,
-            "nesterov": True,
-            "ns_steps": 5,
-            "none_grad": False,
-        }
-    ]
+    param_groups = [{
+        "params": params,
+        "names": names,
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
    optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()
 
@@ -867,8 +841,8 @@ def test_checkpoint_memory(rank, world_size):
     )
 
     with pytest.raises(
-        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint save"
-    ):
+            RuntimeError,
+            match="turn_off_cpu_offload\\(\\) before checkpoint save"):
         optim.state_dict()
 
     optim.turn_off_cpu_offload()
@@ -885,28 +859,24 @@ def test_checkpoint_memory(rank, world_size):
     assert mem_after_turn_off > mem_after_step, (
         f"turn_off_cpu_offload() should reload states to GPU. "
         f"After offload: {mem_after_step / 1024**2:.2f} MB, "
-        f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB"
-    )
+        f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB")
 
     optim.turn_on_cpu_offload()
     torch.cuda.synchronize()
     mem_after_turn_on = torch.cuda.memory_allocated()
 
     if rank == 0:
-        logger.info(
-            "After turn_on_cpu_offload: GPU alloc=%.2f MB", mem_after_turn_on / 1024**2
-        )
+        logger.info("After turn_on_cpu_offload: GPU alloc=%.2f MB",
+                    mem_after_turn_on / 1024**2)
 
     assert mem_after_turn_on <= mem_after_step + 4 * 1024 * 1024, (
         f"turn_on_cpu_offload() should return memory to offloaded level. "
         f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
-        f"got {mem_after_turn_on / 1024**2:.2f} MB"
-    )
+        f"got {mem_after_turn_on / 1024**2:.2f} MB")
 
     for p in params:
-        p.grad = distribute_tensor(
-            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-        )
+        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                   mesh, [Shard(0)])
     optim.step()
     torch.cuda.synchronize()
 
@@ -922,12 +892,11 @@ def test_checkpoint_memory(rank, world_size):
     assert mem_after_next_step <= mem_after_step + 4 * 1024 * 1024, (
         f"Memory should return to offloaded level after step(). "
         f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
-        f"got {mem_after_next_step / 1024**2:.2f} MB"
-    )
+        f"got {mem_after_next_step / 1024**2:.2f} MB")
 
     with pytest.raises(
-        RuntimeError, match="turn_off_cpu_offload\\(\\) before checkpoint load"
-    ):
+            RuntimeError,
+            match="turn_off_cpu_offload\\(\\) before checkpoint load"):
         optim.load_state_dict(copy.deepcopy(sd_for_load))
 
     optim.turn_off_cpu_offload()
@@ -943,24 +912,21 @@ def test_checkpoint_memory(rank, world_size):
     )
 
     assert mem_after_load >= mem_after_turn_off, (
-        "Loaded optimizer state should stay on GPU while offload is disabled"
-    )
+        "Loaded optimizer state should stay on GPU while offload is disabled")
 
     optim.turn_on_cpu_offload()
     torch.cuda.synchronize()
 
     pool = optim._cpu_offload_pool
     assert pool._initialized, (
-        "Offload pool should be initialized after re-enabling offload"
-    )
+        "Offload pool should be initialized after re-enabling offload")
     for grp in pool._groups.values():
         assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned"
 
     # Step 5: verify the loaded optimizer can still step correctly.
     for p in params:
-        p.grad = distribute_tensor(
-            torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)]
-        )
+        p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
+                                   mesh, [Shard(0)])
     optim.step()
     torch.cuda.synchronize()
 
@@ -968,8 +934,7 @@ def test_checkpoint_memory(rank, world_size):
     assert mem_final <= mem_after_step + 4 * 1024 * 1024, (
         f"Final memory should be at offloaded level. "
         f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
-        f"got {mem_final / 1024**2:.2f} MB"
-    )
+        f"got {mem_final / 1024**2:.2f} MB")
 
     set_ns_compile(True)
     if rank == 0:
torch-ext/optimizer/cpu_offload.py CHANGED
@@ -20,6 +20,7 @@ from collections import defaultdict
 
 import torch
 from torch.distributed.tensor import DTensor
+from torch.profiler import record_function
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +36,9 @@ class CPUOffloadPool:
     def __init__(self):
         self._managed: list[torch.Tensor] = []
         self._storage_nbytes: dict[int, int] = {}  # id(t) → bytes
+        # Optional tag → managed-indices map for group-wise reload
+        # (e.g. per-layer lockstep reload driven by backward hooks).
+        self._tag_to_indices: dict[str, list[int]] = {}
 
         # Per-dtype group: populated on first offload.
         # dtype → dict with keys:
@@ -45,6 +49,8 @@ class CPUOffloadPool:
         self._groups: dict[torch.dtype, dict] = {}
 
         self._offload_stream: torch.cuda.Stream | None = None
+        self._reload_stream: torch.cuda.Stream | None = None
+        self._reload_event: torch.cuda.Event | None = None
         self._device: torch.device | None = None
         self._initialized: bool = False
         self._logged: bool = False
@@ -59,9 +65,28 @@ class CPUOffloadPool:
         if self._offload_stream is None:
             self._offload_stream = torch.cuda.Stream(device=self._device)
 
+    def _ensure_reload_stream(self):
+        if self._reload_stream is None:
+            least_priority, _ = torch.cuda.Stream.priority_range()
+            self._reload_stream = torch.cuda.Stream(
+                device=self._device,
+                priority=least_priority,
+            )
+            logger.info(
+                "[CPUOffload] reload stream created with priority=%d "
+                "(range: %d..%d)",
+                least_priority,
+                *torch.cuda.Stream.priority_range(),
+            )
+
     # ------------------------------------------------------------------
-    def track(self, tensor: torch.Tensor):
-        """Register a GPU tensor for CPU offloading. Idempotent."""
+    def track(self, tensor: torch.Tensor, tag: str | None = None):
+        """Register a GPU tensor for CPU offloading. Idempotent.
+
+        If ``tag`` is given, the tensor's managed index is recorded under
+        that tag so callers can trigger a partial reload via
+        :meth:`reload_group`.
+        """
         tid = id(tensor)
         if tid in self._storage_nbytes:
             return
@@ -73,7 +98,10 @@ class CPUOffloadPool:
         if storage.size() == 0:
             return
         self._storage_nbytes[tid] = storage.size()
+        idx = len(self._managed)
         self._managed.append(tensor)
+        if tag is not None:
+            self._tag_to_indices.setdefault(tag, []).append(idx)
 
     # ------------------------------------------------------------------
     def _init_buffers(self):
@@ -93,7 +121,10 @@ class CPUOffloadPool:
             indices.append(idx)
             offsets.append((off, n))
             off += n
-        cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
+        cpu_flat = torch.empty(off,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=True)
         self._groups[dtype] = {
             "indices": indices,
             "offsets": offsets,
@@ -137,7 +168,8 @@ class CPUOffloadPool:
             for i, mgd_idx in enumerate(indices):
                 local = self._local(self._managed[mgd_idx])
                 off, n = offsets[i]
-                cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
+                cpu_flat[off:off + n].copy_(local.reshape(-1),
+                                            non_blocking=True)
 
             offloaded_bytes += grp["total"] * cpu_flat.element_size()
 
@@ -151,8 +183,7 @@ class CPUOffloadPool:
                 raise RuntimeError(
                     f"Tensor storage is already freed (size=0) before offload. "
                     f"This indicates a double-free or external interference. "
-                    f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                )
+                    f"Tensor shape: {t.shape}, dtype: {t.dtype}")
 
         if not self._logged:
             logger.info(
@@ -162,45 +193,172 @@ class CPUOffloadPool:
 
     # ------------------------------------------------------------------
     def reload(self):
-        """Per-tensor H2D from CPU flat buffer on the default stream.
+        """Per-tensor H2D from CPU flat buffer.
 
-        Runs on the current (default) CUDA stream to avoid stream
-        interaction issues with the parallel Muon pipeline. Since
-        pinned CPU memory is the source, the copies overlap with
-        GPU idle time between steps.
+        Storage re-allocation (``resize_``) runs on the current (default)
+        stream. H2D copies run on a dedicated ``_reload_stream``.
+
+        Call :meth:`wait_reload` before consuming the reloaded tensors.
         """
         if not self._managed or not self._initialized:
             return
+        self._ensure_reload_stream()
 
         reloaded_bytes = 0
 
-        # Re-allocate all GPU storages first.
-        for t in self._managed:
-            local = self._local(t)
-            storage = local.untyped_storage()
-            if storage.size() != 0:
-                raise RuntimeError(
-                    f"Storage should have been freed (size=0) before reload, "
-                    f"but got size={storage.size()}. "
-                    f"Tensor shape: {t.shape}, dtype: {t.dtype}"
-                )
-            storage.resize_(self._storage_nbytes[id(t)])
-
-        # Per-tensor H2D copies from CPU flat buffer slices.
-        # non_blocking=True with pinned source allows DMA overlap.
-        for dtype, grp in self._groups.items():
-            indices = grp["indices"]
-            offsets = grp["offsets"]
-            cpu_flat = grp["cpu_flat"]
-
-            for i, mgd_idx in enumerate(indices):
-                local = self._local(self._managed[mgd_idx])
-                off, n = offsets[i]
-                local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
-
-            reloaded_bytes += grp["total"] * cpu_flat.element_size()
+        # Re-allocate all GPU storages with per-tensor profiling.
+        with record_function("CPUOffload::resize_storages"):
+            for i, t in enumerate(self._managed):
+                local = self._local(t)
+                storage = local.untyped_storage()
+                if storage.size() != 0:
+                    raise RuntimeError(
+                        f"Storage should have been freed (size=0) before reload, "
+                        f"but got size={storage.size()}. "
+                        f"Tensor shape: {t.shape}, dtype: {t.dtype}")
+                nbytes = self._storage_nbytes[id(t)]
+                with record_function(f"resize_[{i}]_{nbytes // 1024}KB"):
+                    storage.resize_(nbytes)
+
+        # Reload stream waits for the resize_ ops to finish.
+        alloc_event = torch.cuda.current_stream(self._device).record_event()
+        self._reload_stream.wait_event(alloc_event)
+
+        # Per-tensor H2D copies on the reload stream.
+        with record_function("CPUOffload::h2d_copies"):
+            with torch.cuda.stream(self._reload_stream):
+                for dtype, grp in self._groups.items():
+                    indices = grp["indices"]
+                    offsets = grp["offsets"]
+                    cpu_flat = grp["cpu_flat"]
+
+                    for i, mgd_idx in enumerate(indices):
+                        local = self._local(self._managed[mgd_idx])
+                        off, n = offsets[i]
+                        local.reshape(-1).copy_(cpu_flat[off:off + n],
+                                                non_blocking=True)
+
+                    reloaded_bytes += grp["total"] * cpu_flat.element_size()
+
+        self._reload_event = self._reload_stream.record_event()
 
         if not self._logged:
             logger.info(
-                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+                "[CPUOffload] Reloaded %.2f MB (CPU → GPU, async)",
+                reloaded_bytes / (1024**2),
             )
+            self._logged = True
+
+    def reload_group(self, tag: str, sync_streams: tuple = ()):
+        """Reload only the managed tensors registered under ``tag``.
+
+        Intended for layer-lockstep overlap: backward frees a layer's
+        activations, then the backward hook calls ``reload_group`` with
+        that layer's tag so the H2D copy reuses the freshly-freed memory
+        from the default stream's allocator pool.
+
+        ``sync_streams`` is an optional iterable of CUDA streams whose
+        currently-queued work must complete before the H2D memcpy runs.
+        This is used to avoid allocator cross-stream reuse races under
+        ``expandable_segments``: if a just-freed block was last used on
+        FSDP's all-gather stream, making the reload stream wait on that
+        stream guarantees FIFO ordering between the block's prior use
+        and our H2D write.
+        """
+        if not self._managed or not self._initialized:
+            return
+        indices = self._tag_to_indices.get(tag)
+        if not indices:
+            return
+        self._ensure_reload_stream()
+
+        # Sync reload_stream with the supplied streams (e.g. FSDP AG
+        # streams) before we queue any H2D: ensures past uses of any
+        # allocator block we're about to reuse are fully drained.
+        for s in sync_streams:
+            if s is not None:
+                self._reload_stream.wait_stream(s)
+
+        idx_set = set(indices)
+
+        with record_function(f"CPUOffload::group_resize[{tag}]"):
+            for i in indices:
+                t = self._managed[i]
+                local = self._local(t)
+                storage = local.untyped_storage()
+                if storage.size() == 0:
+                    storage.resize_(self._storage_nbytes[id(t)])
+
+        alloc_event = torch.cuda.current_stream(self._device).record_event()
+        self._reload_stream.wait_event(alloc_event)
+
+        with record_function(f"CPUOffload::group_h2d[{tag}]"):
+            with torch.cuda.stream(self._reload_stream):
+                for dtype, grp in self._groups.items():
+                    indices_grp = grp["indices"]
+                    offsets = grp["offsets"]
+                    cpu_flat = grp["cpu_flat"]
+
+                    for i, mgd_idx in enumerate(indices_grp):
+                        if mgd_idx not in idx_set:
+                            continue
+                        local = self._local(self._managed[mgd_idx])
+                        off, n = offsets[i]
+                        local.reshape(-1).copy_(cpu_flat[off:off + n],
+                                                non_blocking=True)
+
+        self._reload_event = self._reload_stream.record_event()
+
+    def reload_untagged(self):
+        """Reload managed tensors that were not registered under any tag.
+
+        Useful when a subset of params (e.g. MoE experts) is driven via
+        per-tag layer-lockstep hooks while the remainder should still be
+        reloaded before optimizer.step() in a single bulk call.
+        """
+        if not self._managed or not self._initialized:
+            return
+        tagged: set[int] = set()
+        for idx_list in self._tag_to_indices.values():
+            tagged.update(idx_list)
+        untagged = [i for i in range(len(self._managed)) if i not in tagged]
+        if not untagged:
+            return
+        self._ensure_reload_stream()
+
+        idx_set = set(untagged)
+
+        with record_function("CPUOffload::untagged_resize"):
+            for i in untagged:
+                t = self._managed[i]
+                local = self._local(t)
+                storage = local.untyped_storage()
+                if storage.size() == 0:
+                    storage.resize_(self._storage_nbytes[id(t)])
+
+        alloc_event = torch.cuda.current_stream(self._device).record_event()
+        self._reload_stream.wait_event(alloc_event)
+
+        with record_function("CPUOffload::untagged_h2d"):
+            with torch.cuda.stream(self._reload_stream):
+                for dtype, grp in self._groups.items():
+                    indices_grp = grp["indices"]
+                    offsets = grp["offsets"]
+                    cpu_flat = grp["cpu_flat"]
+
+                    for i, mgd_idx in enumerate(indices_grp):
+                        if mgd_idx not in idx_set:
+                            continue
+                        local = self._local(self._managed[mgd_idx])
+                        off, n = offsets[i]
+                        local.reshape(-1).copy_(cpu_flat[off:off + n],
+                                                non_blocking=True)
+
+        self._reload_event = self._reload_stream.record_event()
+
+    def wait_reload(self):
+        """Block the current (default) stream until reload H2D completes."""
+        if self._reload_event is not None:
+            torch.cuda.current_stream(self._device).wait_event(
+                self._reload_event)
+            self._reload_event = None
torch-ext/optimizer/muon.py CHANGED
@@ -242,8 +242,12 @@ class Muon(torch.optim.Optimizer):
         self.use_distributed_muon = use_distributed_muon
         self.expert_keys = expert_keys
         self.cpu_offload = False
+        self.manual_offload = False
         self._cpu_offload_pool: CPUOffloadPool | None = None
         self._offload_initialized = False
+        # id(param) -> tag, consumed by _register_states_for_offload so the
+        # offload pool can do group-wise reload (e.g. per-layer lockstep).
+        self._param_tags: dict[int, str] = {}
         self._parallel_cache: dict[tuple[str, ...], dict] = {}
         self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
@@ -955,15 +959,16 @@ class Muon(torch.optim.Optimizer):
             if p not in self.state:
                 continue
             state = self.state[p]
+            tag = self._param_tags.get(id(p))
             if group.get("use_muon", False):
                 if "momentum_buffer" in state:
-                    pool.track(state["momentum_buffer"])
+                    pool.track(state["momentum_buffer"], tag=tag)
                     tracked += 1
             else:
                 if "moment1" in state:
-                    pool.track(state["moment1"])
+                    pool.track(state["moment1"], tag=tag)
                 if "moment2" in state:
-                    pool.track(state["moment2"])
+                    pool.track(state["moment2"], tag=tag)
                 tracked += 1
         logger.info("[CPUOffload] Registered %d param states for offload",
                     tracked)
@@ -986,8 +991,10 @@ class Muon(torch.optim.Optimizer):
             loss = closure()
 
         # H2D: reload optimizer states from CPU before computation.
-        if self.cpu_offload and self._offload_initialized:
-            self._cpu_offload_pool.reload()
+        if not self.manual_offload:
+            if self.cpu_offload and self._offload_initialized:
+                self._cpu_offload_pool.reload()
+                self._cpu_offload_pool.wait_reload()
 
         logger.debug("[Muon.step] expert_keys=%s, %d param groups",
                      self.expert_keys, len(self.param_groups))
@@ -1004,6 +1011,53 @@ class Muon(torch.optim.Optimizer):
                 step_adamw(self.state, group)
 
         # D2H: offload optimizer states to CPU after computation.
+        if not self.manual_offload:
+            if self.cpu_offload:
+                if not self._offload_initialized:
+                    if self._cpu_offload_pool is None:
+                        self._cpu_offload_pool = CPUOffloadPool()
+                    self._register_states_for_offload()
+                    self._offload_initialized = True
+                self._cpu_offload_pool.offload()
+
+        return loss
+
+    # ------------------------------------------------------------------
+    # CPU offload public helpers
+    # ------------------------------------------------------------------
+
+    def reload_group(self, tag: str, sync_streams: tuple = ()):
+        """Reload optimizer states registered under ``tag``.
+
+        Tags are set via :meth:`set_param_tags` before the first step.
+        ``sync_streams`` forwards to :meth:`CPUOffloadPool.reload_group`
+        so callers (e.g. FSDP pre/post-hook patches) can make the reload
+        stream wait on collective streams before its H2D runs.
+        """
+        if self.cpu_offload and self._offload_initialized:
+            self._cpu_offload_pool.reload_group(tag, sync_streams=sync_streams)
+
+    def reload_untagged(self):
+        """Reload all optimizer states not attached to any tag."""
+        if self.cpu_offload and self._offload_initialized:
+            self._cpu_offload_pool.reload_untagged()
+
+    def set_param_tags(self, param_tags: dict[int, str]) -> None:
+        """Attach an ``id(param) -> tag`` mapping for group-wise reload.
+
+        Must be called before the first ``step()`` (i.e. before
+        :meth:`_register_states_for_offload`) so the pool receives tags
+        when states are first registered.
+        """
+        self._param_tags = dict(param_tags)
+
+    def wait_reload(self):
+        """Block the default stream until the async reload completes."""
+        if self.cpu_offload and self._offload_initialized:
+            self._cpu_offload_pool.wait_reload()
+
+    def offload(self):
+        """Offload optimizer states from GPU to CPU (D2H)."""
         if self.cpu_offload:
             if not self._offload_initialized:
                 if self._cpu_offload_pool is None:
@@ -1012,12 +1066,6 @@ class Muon(torch.optim.Optimizer):
                 self._offload_initialized = True
             self._cpu_offload_pool.offload()
 
-        return loss
-
-    # ------------------------------------------------------------------
-    # CPU offload public helpers
-    # ------------------------------------------------------------------
-
     def turn_on_cpu_offload(self):
         """Enable CPU offload for optimizer states."""
         if self.cpu_offload:
@@ -1039,6 +1087,7 @@ class Muon(torch.optim.Optimizer):
         logger.info("[Muon] turn_off_cpu_offload")
         if self._offload_initialized:
             self._cpu_offload_pool.reload()
+            self._cpu_offload_pool.wait_reload()
             torch.cuda.current_stream().synchronize()
         self._cpu_offload_pool = None
         self._offload_initialized = False
torch-ext/optimizer/newton_schulz.py CHANGED
@@ -32,30 +32,28 @@ def _optimal_quintic(l, u, max_iter=1000):
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array(
-            [
-                [l, l**3, l**5, 1],
-                [q, q**3, q**5, -1],
-                [r, r**3, r**5, 1],
-                [u, u**3, u**5, -1],
-            ]
-        )
+        LHS = np.array([
+            [l, l**3, l**5, 1],
+            [q, q**3, q**5, -1],
+            [r, r**3, r**5, 1],
+            [u, u**3, u**5, -1],
+        ])
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
             raise ValueError(
                 f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
             )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
-        )
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
+            (10 * c))
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations"
-        )
+            f"_optimal_quintic: did not converge after {max_iter} iterations")
     return float(a), float(b), float(c)
 
 
@@ -114,9 +112,10 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(
-    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
-)
+_coeffs_list = _optimal_composition(l=1e-3,
+                                    num_iters=10,
+                                    safety_factor_eps=1e-2,
+                                    cushion=0.02)
 
 
 # This code is adapted from:
@@ -150,8 +149,7 @@ def _zeropower_via_newtonschulz5(G, steps):
 
     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-    )
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -186,8 +184,7 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
 
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
-    )
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))