Kernels
wyldecat Claude Opus 4.6 committed on
Commit cd587a6 · 1 Parent(s): 7e33533

Add PP + dp_replicate deadlock regression tests [skip-build]


Add tests simulating Pipeline Parallelism with dp_replicate > 1 to verify
construct_shard_mesh doesn't deadlock when different PP stages call
dist.new_group for different parameters independently.

- test_muon.py: Dense model test (PP=2, dp_replicate=2, dp_shard=2)
- test_muon_moe.py: MoE test with asymmetric expert/non-expert params
- test_pp.yaml: K8s job spec for running on GPU pod

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
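
For context on the failure mode: dist.new_group() is by default collective over the entire world, so every rank must reach it the same number of times in the same order. Once PP stages own different parameter sets, their ranks enter it a different number of times and all ranks block. A minimal sketch of the pattern, with a hypothetical make_stage_groups helper (only use_local_synchronization comes from the actual fix):

import torch.distributed as dist

def make_stage_groups(pp_rank):
    # Hypothetical layout: each PP stage owns 4 of the 8 ranks.
    stage_ranks = [0, 1, 2, 3] if pp_rank == 0 else [4, 5, 6, 7]
    # Stage 1 has more parameters, so its ranks would call
    # dist.new_group more times than stage 0's ranks.
    num_groups = 1 if pp_rank == 0 else 2
    groups = []
    for _ in range(num_groups):
        # Default new_group() is a world-wide rendezvous: stage 0 ranks
        # stop after one call while stage 1 ranks wait on a second one.
        # use_local_synchronization=True only requires the member ranks,
        # so asymmetric call counts across stages are safe.
        groups.append(
            dist.new_group(ranks=stage_ranks,
                           use_local_synchronization=True))
    return groups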

Files changed (2)
  1. test/test_muon.py +119 -0
  2. test/test_muon_moe.py +138 -0
test/test_muon.py CHANGED
@@ -391,3 +391,122 @@ def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
     set_ns_compile(True)
     logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)",
                 uneven_dim, rank)
+
+
+def test_pp_dp_replicate_no_deadlock(init_dist):
+    """Regression: PP-like setup where different rank subsets call
+    construct_shard_mesh for different parameters must not deadlock.
+
+    Simulates PP=2 with dp_replicate=2, dp_shard=2. Each PP stage has
+    4 ranks with a (2, 2) mesh and [Replicate(), Shard(0)] placements.
+    Stages create different numbers of parameters, forcing
+    construct_shard_mesh to be called independently per stage.
+    Without use_local_synchronization=True in dist.new_group(),
+    this would deadlock.
+    """
+    from optimizer.distributed.utils import _ranks_to_dist_cache
+    from optimizer.newton_schulz import set_ns_compile
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    assert world_size == 8
+
+    set_ns_compile(False)
+
+    # Clear cache to ensure dist.new_group is actually called
+    _ranks_to_dist_cache.clear()
+
+    # Create full mesh: PP=2, dp_replicate=2, dp_shard=2
+    full_mesh = dist.init_device_mesh(
+        "cuda",
+        (2, 2, 2),
+        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
+    )
+
+    # Per-stage submesh (shape (2, 2), 4 ranks each)
+    stage_mesh = full_mesh["dp_replicate", "dp_shard"]
+    pp_rank = full_mesh.get_local_rank("pp")
+
+    torch.manual_seed(42 + pp_rank)
+
+    # Asymmetric param counts: stage 0 gets 3, stage 1 gets 5
+    num_params = 3 if pp_rank == 0 else 5
+    placements = [Replicate(), Shard(0)]
+
+    muon_params = []
+    muon_names = []
+    full_params_snapshot = []
+    full_grads = []
+
+    for i in range(num_params):
+        full = torch.randn(32, 64, device="cuda")
+        full_params_snapshot.append(full.clone())
+        dt = distribute_tensor(full, stage_mesh, placements)
+        p = torch.nn.Parameter(dt)
+        grad_full = torch.randn(32, 64, device="cuda")
+        full_grads.append(grad_full.clone())
+        p.grad = distribute_tensor(grad_full, stage_mesh, placements)
+        muon_params.append(p)
+        muon_names.append(f"stage{pp_rank}.layer.{i}.weight")
+
+    param_groups = [{
+        "params": muon_params,
+        "names": muon_names,
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
+
+    # Must not deadlock
+    optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
+    optim.step()
+
+    # Second step to verify the cached path; record the grads so the
+    # sequential baseline can replay the same two steps.
+    second_grads = []
+    for p in muon_params:
+        grad_full = torch.randn(32, 64, device="cuda")
+        second_grads.append(grad_full.clone())
+        p.grad = distribute_tensor(grad_full, stage_mesh, placements)
+    optim.step()
+
+    # Correctness: compare against a sequential baseline that takes the
+    # same two steps with the same grads
+    seq_params = []
+    for fp, fg in zip(full_params_snapshot, full_grads):
+        p = torch.nn.Parameter(fp.clone())
+        p.grad = fg.clone()
+        seq_params.append(p)
+
+    param_groups_seq = [{
+        "params": seq_params,
+        "names": list(muon_names),
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
+    optim_seq = Muon(params=param_groups_seq)
+    optim_seq.step()
+    for p, g in zip(seq_params, second_grads):
+        p.grad = g.clone()
+    optim_seq.step()
+
+    for i in range(num_params):
+        par_full = muon_params[i].data.full_tensor()
+        torch.testing.assert_close(par_full,
+                                   seq_params[i].data,
+                                   atol=0,
+                                   rtol=0)
+
+    set_ns_compile(True)
+    logger.info(
+        "test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
+        pp_rank)
test/test_muon_moe.py CHANGED
@@ -402,3 +402,141 @@ def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
     logger.info(
         "test_parallel_muon_moe_uneven_shard (dim=%d) PASSED (rank %d)",
         uneven_dim, rank)
+
+
+def test_pp_dp_replicate_moe_no_deadlock(init_dist):
+    """Regression: PP-like MoE setup where different stages have different
+    parameter types must not deadlock in construct_shard_mesh.
+
+    Simulates PP=2 with dp_replicate=2, dp_shard=2. Stage 0 has only
+    non-expert 2D DTensor params; stage 1 has non-expert 2D DTensor params
+    plus 3D expert plain-tensor params. This mirrors real PP+MoE where
+    expert layers exist only in certain stages.
+    """
+    from optimizer.distributed.utils import _ranks_to_dist_cache
+    from optimizer.newton_schulz import set_ns_compile
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    assert world_size == 8
+
+    set_ns_compile(False)
+
+    # Clear cache to ensure dist.new_group is actually called
+    _ranks_to_dist_cache.clear()
+
+    # Create full mesh: PP=2, dp_replicate=2, dp_shard=2
+    full_mesh = dist.init_device_mesh(
+        "cuda",
+        (2, 2, 2),
+        mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
+    )
+
+    stage_mesh = full_mesh["dp_replicate", "dp_shard"]
+    pp_rank = full_mesh.get_local_rank("pp")
+
+    torch.manual_seed(42 + pp_rank)
+
+    placements = [Replicate(), Shard(0)]
+    num_experts = 4
+
+    muon_params = []
+    muon_names = []
+    full_params_snapshot = []
+    full_grads = []
+
+    # Non-expert 2D DTensor params (both stages, different counts)
+    num_dense = 2 if pp_rank == 0 else 3
+    for i in range(num_dense):
+        full = torch.randn(32, 64, device="cuda")
+        full_params_snapshot.append(full.clone())
+        dt = distribute_tensor(full, stage_mesh, placements)
+        p = torch.nn.Parameter(dt)
+        g = torch.randn(32, 64, device="cuda")
+        full_grads.append(g.clone())
+        p.grad = distribute_tensor(g, stage_mesh, placements)
+        muon_params.append(p)
+        muon_names.append(f"stage{pp_rank}.layers.{i}.weight")
+
+    # Stage 1 only: 3D expert plain-tensor params
+    if pp_rank == 1:
+        full = torch.randn(num_experts, 32, 64, device="cuda")
+        full_params_snapshot.append(full.clone())
+        p = torch.nn.Parameter(full)
+        g = torch.randn(num_experts, 32, 64, device="cuda")
+        full_grads.append(g.clone())
+        p.grad = g
+        muon_params.append(p)
+        muon_names.append(
+            f"stage{pp_rank}.layers.{num_dense}.experts.w1.weight")
+
+    param_groups = [{
+        "params": muon_params,
+        "names": muon_names,
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
+
+    # Must not deadlock
+    optim = Muon(params=param_groups,
+                 chunk_size=1,
+                 warmup_step=0,
+                 expert_keys=["experts"])
+    optim.step()
+
+    # Second step to verify the cached path; record the grads so the
+    # sequential baseline can replay the same two steps.
+    second_grads = []
+    for p in muon_params:
+        if isinstance(p.data, DTensor):
+            g = torch.randn(32, 64, device="cuda")
+            p.grad = distribute_tensor(g, stage_mesh, placements)
+        else:
+            g = torch.randn_like(p.data)
+            p.grad = g
+        second_grads.append(g.clone())
+    optim.step()
+
+    # Correctness: compare against a sequential baseline that takes the
+    # same two steps with the same grads
+    seq_params = []
+    for fp in full_params_snapshot:
+        seq_params.append(torch.nn.Parameter(fp.clone()))
+    for p, g in zip(seq_params, full_grads):
+        p.grad = g.clone()
+
+    param_groups_seq = [{
+        "params": seq_params,
+        "names": list(muon_names),
+        "use_muon": True,
+        "lr": 0.02,
+        "weight_decay": 0.01,
+        "momentum": 0.95,
+        "nesterov": True,
+        "ns_steps": 5,
+        "none_grad": False,
+    }]
+    optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
+    optim_seq.step()
+    for p, g in zip(seq_params, second_grads):
+        p.grad = g.clone()
+    optim_seq.step()
+
+    for i in range(len(muon_params)):
+        par_data = muon_params[i].data
+        if isinstance(par_data, DTensor):
+            par_data = par_data.full_tensor()
+        torch.testing.assert_close(par_data,
+                                   seq_params[i].data,
+                                   atol=0,
+                                   rtol=0)
+
+    set_ns_compile(True)
+    logger.info(
+        "test_pp_dp_replicate_moe_no_deadlock PASSED (rank %d, pp_rank %d)",
+        rank, pp_rank)
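
A note on the cache these tests clear: _ranks_to_dist_cache is emptied before the first step so dist.new_group really runs, and the second optim.step() then exercises the cached path. A minimal sketch of that caching pattern, with a hypothetical _get_group helper (only _ranks_to_dist_cache and use_local_synchronization appear in the source):

import torch.distributed as dist

# Maps a sorted tuple of global ranks to its process group, so repeated
# shard-mesh construction reuses existing groups instead of re-entering
# dist.new_group; the second optim.step() in each test hits this path.
_ranks_to_dist_cache = {}

def _get_group(ranks):
    key = tuple(sorted(ranks))
    if key not in _ranks_to_dist_cache:
        # Only member ranks must call new_group here, so PP stages with
        # different parameter counts cannot deadlock each other.
        _ranks_to_dist_cache[key] = dist.new_group(
            ranks=list(key), use_local_synchronization=True)
    return _ranks_to_dist_cache[key]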