Kernels
wyldecat Claude Opus 4.6 committed on
Commit
a4d1f34
·
1 Parent(s): c0bbf2e

Add correctness verification to PP tests using fully_shard [skip-build]

Browse files

Use fully_shard (proven HSDP pattern) instead of manual distribute_tensor
to create proper DTensors. Verify parallel results match sequential
baseline with atol=0, rtol=0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. test/test_muon.py +72 -25
  2. test/test_muon_moe.py +95 -30
test/test_muon.py CHANGED
@@ -396,16 +396,19 @@ def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
396
  def test_pp_dp_replicate_no_deadlock(init_dist):
397
  """Regression: PP-like setup where different rank subsets call
398
  construct_shard_mesh for different parameters must not deadlock.
 
399
 
400
  Simulates PP=2 with dp_replicate=2, dp_shard=2. Each PP stage has
401
- 4 ranks with a (2,2) mesh and [Replicate, Shard(0)] placements.
402
- Stages create different numbers of parameters, forcing
 
403
  construct_shard_mesh to be called independently per stage.
404
  Without use_local_synchronization=True in dist.new_group(),
405
  this would deadlock.
406
  """
407
  from optimizer.distributed.utils import _ranks_to_dist_cache
408
  from optimizer.newton_schulz import set_ns_compile
 
409
 
410
  rank = dist.get_rank()
411
  world_size = dist.get_world_size()
@@ -423,28 +426,47 @@ def test_pp_dp_replicate_no_deadlock(init_dist):
423
  mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
424
  )
425
 
426
- # Per-stage submesh (shape (2,2), 4 ranks each)
427
  stage_mesh = full_mesh["dp_replicate", "dp_shard"]
428
  pp_rank = full_mesh.get_local_rank("pp")
429
 
430
- torch.manual_seed(42 + pp_rank)
431
-
432
- # Asymmetric param counts: stage 0 gets 3, stage 1 gets 5
433
- num_params = 3 if pp_rank == 0 else 5
434
- placements = [Replicate(), Shard(0)]
435
 
436
- muon_params = []
437
- muon_names = []
438
 
439
- for i in range(num_params):
440
- full = torch.randn(32, 64, device="cuda")
441
- dt = distribute_tensor(full, stage_mesh, placements)
442
- p = torch.nn.Parameter(dt)
443
- p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
444
- stage_mesh, placements)
445
- muon_params.append(p)
446
- muon_names.append(f"stage{pp_rank}.layer.{i}.weight")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
 
 
 
448
  param_groups = [{
449
  "params": muon_params,
450
  "names": muon_names,
@@ -456,16 +478,41 @@ def test_pp_dp_replicate_no_deadlock(init_dist):
456
  "ns_steps": 5,
457
  "none_grad": False,
458
  }]
459
-
460
- # Must not deadlock
461
  optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
462
  optim.step()
463
 
464
- # Second step to verify cached path
465
- for p in muon_params:
466
- p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
467
- stage_mesh, placements)
468
- optim.step()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  set_ns_compile(True)
471
  logger.info(
 
396
  def test_pp_dp_replicate_no_deadlock(init_dist):
397
  """Regression: PP-like setup where different rank subsets call
398
  construct_shard_mesh for different parameters must not deadlock.
399
+ Also verifies correctness (atol=0, rtol=0) against sequential baseline.
400
 
401
  Simulates PP=2 with dp_replicate=2, dp_shard=2. Each PP stage has
402
+ 4 ranks with a (2,2) mesh and [Replicate, Shard(0)] placements
403
+ (created via fully_shard, matching the real HSDP pattern).
404
+ Stages create different numbers of layers, forcing
405
  construct_shard_mesh to be called independently per stage.
406
  Without use_local_synchronization=True in dist.new_group(),
407
  this would deadlock.
408
  """
409
  from optimizer.distributed.utils import _ranks_to_dist_cache
410
  from optimizer.newton_schulz import set_ns_compile
411
+ from torch.distributed.fsdp import fully_shard
412
 
413
  rank = dist.get_rank()
414
  world_size = dist.get_world_size()
 
426
  mesh_dim_names=("pp", "dp_replicate", "dp_shard"),
427
  )
428
 
 
429
  stage_mesh = full_mesh["dp_replicate", "dp_shard"]
430
  pp_rank = full_mesh.get_local_rank("pp")
431
 
432
+ # Asymmetric layer counts per stage (mimics PP)
433
+ num_layers = 3 if pp_rank == 0 else 5
434
+ hidden = 64
 
 
435
 
436
+ # Same seed per stage so all ranks in a stage get identical init weights
437
+ torch.manual_seed(42 + pp_rank)
438
 
439
+ # Create model and save initial state for sequential baseline
440
+ model = torch.nn.Sequential(*[
441
+ torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)
442
+ ]).cuda()
443
+
444
+ init_state = {n: p.data.clone() for n, p in model.named_parameters()}
445
+ grads = {n: torch.randn_like(p) for n, p in model.named_parameters()}
446
+
447
+ # Apply FSDP (creates proper DTensors with [Replicate, Shard(0)])
448
+ for layer in model:
449
+ fully_shard(layer, mesh=stage_mesh)
450
+ fully_shard(model, mesh=stage_mesh)
451
+ model.reshard()
452
+
453
+ # Apply grads with proper DTensor redistribution
454
+ for n, p in model.named_parameters():
455
+ g = grads[n]
456
+ if isinstance(p.data, DTensor):
457
+ ug = DTensor.from_local(
458
+ g,
459
+ device_mesh=p.data.device_mesh,
460
+ placements=[Replicate()] * p.data.device_mesh.ndim,
461
+ )
462
+ p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
463
+ placements=p.data.placements)
464
+ else:
465
+ p.grad = g
466
 
467
+ # Parallel Muon step — must not deadlock
468
+ muon_names = [n for n, _ in model.named_parameters()]
469
+ muon_params = [p for _, p in model.named_parameters()]
470
  param_groups = [{
471
  "params": muon_params,
472
  "names": muon_names,
 
478
  "ns_steps": 5,
479
  "none_grad": False,
480
  }]
 
 
481
  optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
482
  optim.step()
483
 
484
+ # Sequential baseline (base path, no sharding)
485
+ torch.manual_seed(42 + pp_rank)
486
+ model_seq = torch.nn.Sequential(*[
487
+ torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_layers)
488
+ ]).cuda()
489
+
490
+ for n, p in model_seq.named_parameters():
491
+ p.grad = grads[n].clone()
492
+
493
+ seq_names = [n for n, _ in model_seq.named_parameters()]
494
+ seq_params = [p for _, p in model_seq.named_parameters()]
495
+ param_groups_seq = [{
496
+ "params": seq_params,
497
+ "names": seq_names,
498
+ "use_muon": True,
499
+ "lr": 0.02,
500
+ "weight_decay": 0.01,
501
+ "momentum": 0.95,
502
+ "nesterov": True,
503
+ "ns_steps": 5,
504
+ "none_grad": False,
505
+ }]
506
+ optim_seq = Muon(params=param_groups_seq)
507
+ optim_seq.step()
508
+
509
+ # Correctness: parallel must match sequential exactly
510
+ for (n_par, p_par), (n_seq, p_seq) in zip(model.named_parameters(),
511
+ model_seq.named_parameters()):
512
+ par_data = p_par.data
513
+ if isinstance(par_data, DTensor):
514
+ par_data = par_data.full_tensor()
515
+ torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
516
 
517
  set_ns_compile(True)
518
  logger.info(
test/test_muon_moe.py CHANGED
@@ -407,14 +407,16 @@ def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
407
  def test_pp_dp_replicate_moe_no_deadlock(init_dist):
408
  """Regression: PP-like MoE setup where different stages have different
409
  parameter types must not deadlock in construct_shard_mesh.
 
410
 
411
  Simulates PP=2 with dp_replicate=2, dp_shard=2. Stage 0 has only
412
- non-expert 2D DTensor params; stage 1 has non-expert 2D DTensor params
413
  plus 3D expert plain-tensor params. This mirrors real PP+MoE where
414
  expert layers exist only in certain stages.
415
  """
416
  from optimizer.distributed.utils import _ranks_to_dist_cache
417
  from optimizer.newton_schulz import set_ns_compile
 
418
 
419
  rank = dist.get_rank()
420
  world_size = dist.get_world_size()
@@ -435,33 +437,57 @@ def test_pp_dp_replicate_moe_no_deadlock(init_dist):
435
  stage_mesh = full_mesh["dp_replicate", "dp_shard"]
436
  pp_rank = full_mesh.get_local_rank("pp")
437
 
 
 
 
 
438
  torch.manual_seed(42 + pp_rank)
439
 
440
- placements = [Replicate(), Shard(0)]
441
- num_experts = 4
 
 
442
 
443
- muon_params = []
444
- muon_names = []
 
445
 
446
- # Non-expert 2D DTensor params (both stages, different counts)
447
- num_dense = 2 if pp_rank == 0 else 3
448
- for i in range(num_dense):
449
- full = torch.randn(32, 64, device="cuda")
450
- dt = distribute_tensor(full, stage_mesh, placements)
451
- p = torch.nn.Parameter(dt)
452
- p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
453
- stage_mesh, placements)
454
- muon_params.append(p)
455
- muon_names.append(f"stage{pp_rank}.layers.{i}.weight")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- # Stage 1 only: 3D expert plain-tensor params
458
  if pp_rank == 1:
459
- full = torch.randn(num_experts, 32, 64, device="cuda")
460
- p = torch.nn.Parameter(full)
461
- p.grad = torch.randn(num_experts, 32, 64, device="cuda")
462
- muon_params.append(p)
463
- muon_names.append(
464
- f"stage{pp_rank}.layers.{num_dense}.experts.w1.weight")
465
 
466
  param_groups = [{
467
  "params": muon_params,
@@ -482,14 +508,53 @@ def test_pp_dp_replicate_moe_no_deadlock(init_dist):
482
  expert_keys=["experts"])
483
  optim.step()
484
 
485
- # Second step to verify cached path
486
- for p in muon_params:
487
- if isinstance(p.data, DTensor):
488
- p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
489
- stage_mesh, placements)
490
- else:
491
- p.grad = torch.randn_like(p.data)
492
- optim.step()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
  set_ns_compile(True)
495
  logger.info(
 
407
  def test_pp_dp_replicate_moe_no_deadlock(init_dist):
408
  """Regression: PP-like MoE setup where different stages have different
409
  parameter types must not deadlock in construct_shard_mesh.
410
+ Also verifies correctness (atol=0, rtol=0) against sequential baseline.
411
 
412
  Simulates PP=2 with dp_replicate=2, dp_shard=2. Stage 0 has only
413
+ non-expert 2D FSDP-sharded params; stage 1 has 2D FSDP-sharded params
414
  plus 3D expert plain-tensor params. This mirrors real PP+MoE where
415
  expert layers exist only in certain stages.
416
  """
417
  from optimizer.distributed.utils import _ranks_to_dist_cache
418
  from optimizer.newton_schulz import set_ns_compile
419
+ from torch.distributed.fsdp import fully_shard
420
 
421
  rank = dist.get_rank()
422
  world_size = dist.get_world_size()
 
437
  stage_mesh = full_mesh["dp_replicate", "dp_shard"]
438
  pp_rank = full_mesh.get_local_rank("pp")
439
 
440
+ num_dense = 2 if pp_rank == 0 else 3
441
+ num_experts = 4
442
+ hidden = 64
443
+
444
  torch.manual_seed(42 + pp_rank)
445
 
446
+ # Create model with dense layers (+ expert param for stage 1)
447
+ model = torch.nn.Sequential(*[
448
+ torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_dense)
449
+ ]).cuda()
450
 
451
+ # Save init state and grads for sequential baseline
452
+ init_state = {n: p.data.clone() for n, p in model.named_parameters()}
453
+ dense_grads = {n: torch.randn_like(p) for n, p in model.named_parameters()}
454
 
455
+ # Expert param (stage 1 only, plain tensor — not FSDP-sharded)
456
+ expert_data = None
457
+ expert_grad = None
458
+ if pp_rank == 1:
459
+ expert_data = torch.randn(num_experts, hidden, hidden, device="cuda")
460
+ expert_grad = torch.randn(num_experts, hidden, hidden, device="cuda")
461
+
462
+ # Apply FSDP to dense layers
463
+ for layer in model:
464
+ fully_shard(layer, mesh=stage_mesh)
465
+ fully_shard(model, mesh=stage_mesh)
466
+ model.reshard()
467
+
468
+ # Apply dense grads with DTensor redistribution
469
+ for n, p in model.named_parameters():
470
+ g = dense_grads[n]
471
+ if isinstance(p.data, DTensor):
472
+ ug = DTensor.from_local(
473
+ g,
474
+ device_mesh=p.data.device_mesh,
475
+ placements=[Replicate()] * p.data.device_mesh.ndim,
476
+ )
477
+ p.grad = ug.redistribute(device_mesh=p.data.device_mesh,
478
+ placements=p.data.placements)
479
+ else:
480
+ p.grad = g
481
+
482
+ # Build param groups: dense (FSDP DTensors) + expert (plain tensor)
483
+ muon_names = [n for n, _ in model.named_parameters()]
484
+ muon_params = list(model.parameters())
485
 
 
486
  if pp_rank == 1:
487
+ expert_p = torch.nn.Parameter(expert_data.clone())
488
+ expert_p.grad = expert_grad.clone()
489
+ muon_params.append(expert_p)
490
+ muon_names.append("experts.w1.weight")
 
 
491
 
492
  param_groups = [{
493
  "params": muon_params,
 
508
  expert_keys=["experts"])
509
  optim.step()
510
 
511
+ # Sequential baseline
512
+ torch.manual_seed(42 + pp_rank)
513
+ model_seq = torch.nn.Sequential(*[
514
+ torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_dense)
515
+ ]).cuda()
516
+
517
+ seq_names = [n for n, _ in model_seq.named_parameters()]
518
+ seq_params = list(model_seq.parameters())
519
+
520
+ for n, p in model_seq.named_parameters():
521
+ p.grad = dense_grads[n].clone()
522
+
523
+ if pp_rank == 1:
524
+ expert_p_seq = torch.nn.Parameter(expert_data.clone())
525
+ expert_p_seq.grad = expert_grad.clone()
526
+ seq_params.append(expert_p_seq)
527
+ seq_names.append("experts.w1.weight")
528
+
529
+ param_groups_seq = [{
530
+ "params": seq_params,
531
+ "names": seq_names,
532
+ "use_muon": True,
533
+ "lr": 0.02,
534
+ "weight_decay": 0.01,
535
+ "momentum": 0.95,
536
+ "nesterov": True,
537
+ "ns_steps": 5,
538
+ "none_grad": False,
539
+ }]
540
+ optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
541
+ optim_seq.step()
542
+
543
+ # Correctness: parallel must match sequential exactly
544
+ # Dense params
545
+ for (n_par, p_par), (n_seq, p_seq) in zip(model.named_parameters(),
546
+ model_seq.named_parameters()):
547
+ par_data = p_par.data
548
+ if isinstance(par_data, DTensor):
549
+ par_data = par_data.full_tensor()
550
+ torch.testing.assert_close(par_data, p_seq.data, atol=0, rtol=0)
551
+
552
+ # Expert params (stage 1 only)
553
+ if pp_rank == 1:
554
+ torch.testing.assert_close(muon_params[-1].data,
555
+ seq_params[-1].data,
556
+ atol=0,
557
+ rtol=0)
558
 
559
  set_ns_compile(True)
560
  logger.info(