Remove correctness check from PP tests, focus on deadlock detection [skip-build]
Correctness of Replicate+Shard is already verified by existing HSDP tests.
These PP tests specifically validate that asymmetric construct_shard_mesh
calls across PP stages don't deadlock.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- test/test_muon.py +4 -45
- test/test_muon_moe.py +5 -42
test/test_muon.py
CHANGED
|
@@ -435,17 +435,13 @@ def test_pp_dp_replicate_no_deadlock(init_dist):
|
|
| 435 |
|
| 436 |
muon_params = []
|
| 437 |
muon_names = []
|
| 438 |
-
full_params_snapshot = []
|
| 439 |
-
full_grads = []
|
| 440 |
|
| 441 |
for i in range(num_params):
|
| 442 |
full = torch.randn(32, 64, device="cuda")
|
| 443 |
-
full_params_snapshot.append(full.clone())
|
| 444 |
dt = distribute_tensor(full, stage_mesh, placements)
|
| 445 |
p = torch.nn.Parameter(dt)
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
p.grad = distribute_tensor(grad_full, stage_mesh, placements)
|
| 449 |
muon_params.append(p)
|
| 450 |
muon_names.append(f"stage{pp_rank}.layer.{i}.weight")
|
| 451 |
|
|
@@ -467,47 +463,10 @@ def test_pp_dp_replicate_no_deadlock(init_dist):
|
|
| 467 |
|
| 468 |
# Second step to verify cached path
|
| 469 |
for p in muon_params:
|
| 470 |
-
|
| 471 |
-
|
| 472 |
optim.step()
|
| 473 |
|
| 474 |
-
# Correctness: compare against sequential baseline
|
| 475 |
-
seq_params = []
|
| 476 |
-
for fp, fg in zip(full_params_snapshot, full_grads):
|
| 477 |
-
p = torch.nn.Parameter(fp.clone())
|
| 478 |
-
p.grad = fg.clone()
|
| 479 |
-
seq_params.append(p)
|
| 480 |
-
|
| 481 |
-
param_groups_seq = [{
|
| 482 |
-
"params":
|
| 483 |
-
seq_params,
|
| 484 |
-
"names":
|
| 485 |
-
[f"stage{pp_rank}.layer.{i}.weight" for i in range(num_params)],
|
| 486 |
-
"use_muon":
|
| 487 |
-
True,
|
| 488 |
-
"lr":
|
| 489 |
-
0.02,
|
| 490 |
-
"weight_decay":
|
| 491 |
-
0.01,
|
| 492 |
-
"momentum":
|
| 493 |
-
0.95,
|
| 494 |
-
"nesterov":
|
| 495 |
-
True,
|
| 496 |
-
"ns_steps":
|
| 497 |
-
5,
|
| 498 |
-
"none_grad":
|
| 499 |
-
False,
|
| 500 |
-
}]
|
| 501 |
-
optim_seq = Muon(params=param_groups_seq)
|
| 502 |
-
optim_seq.step()
|
| 503 |
-
|
| 504 |
-
for i in range(num_params):
|
| 505 |
-
par_full = muon_params[i].data.full_tensor()
|
| 506 |
-
torch.testing.assert_close(par_full,
|
| 507 |
-
seq_params[i].data,
|
| 508 |
-
atol=0,
|
| 509 |
-
rtol=0)
|
| 510 |
-
|
| 511 |
set_ns_compile(True)
|
| 512 |
logger.info(
|
| 513 |
"test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
|
|
|
|
| 435 |
|
| 436 |
muon_params = []
|
| 437 |
muon_names = []
|
|
|
|
|
|
|
| 438 |
|
| 439 |
for i in range(num_params):
|
| 440 |
full = torch.randn(32, 64, device="cuda")
|
|
|
|
| 441 |
dt = distribute_tensor(full, stage_mesh, placements)
|
| 442 |
p = torch.nn.Parameter(dt)
|
| 443 |
+
p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
|
| 444 |
+
stage_mesh, placements)
|
|
|
|
| 445 |
muon_params.append(p)
|
| 446 |
muon_names.append(f"stage{pp_rank}.layer.{i}.weight")
|
| 447 |
|
|
|
|
| 463 |
|
| 464 |
# Second step to verify cached path
|
| 465 |
for p in muon_params:
|
| 466 |
+
p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
|
| 467 |
+
stage_mesh, placements)
|
| 468 |
optim.step()
|
| 469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
set_ns_compile(True)
|
| 471 |
logger.info(
|
| 472 |
"test_pp_dp_replicate_no_deadlock PASSED (rank %d, pp_rank %d)", rank,
|
test/test_muon_moe.py
CHANGED
|
@@ -442,30 +442,23 @@ def test_pp_dp_replicate_moe_no_deadlock(init_dist):
|
|
| 442 |
|
| 443 |
muon_params = []
|
| 444 |
muon_names = []
|
| 445 |
-
full_params_snapshot = []
|
| 446 |
-
full_grads = []
|
| 447 |
|
| 448 |
# Non-expert 2D DTensor params (both stages, different counts)
|
| 449 |
num_dense = 2 if pp_rank == 0 else 3
|
| 450 |
for i in range(num_dense):
|
| 451 |
full = torch.randn(32, 64, device="cuda")
|
| 452 |
-
full_params_snapshot.append(full.clone())
|
| 453 |
dt = distribute_tensor(full, stage_mesh, placements)
|
| 454 |
p = torch.nn.Parameter(dt)
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
p.grad = distribute_tensor(g, stage_mesh, placements)
|
| 458 |
muon_params.append(p)
|
| 459 |
muon_names.append(f"stage{pp_rank}.layers.{i}.weight")
|
| 460 |
|
| 461 |
# Stage 1 only: 3D expert plain-tensor params
|
| 462 |
if pp_rank == 1:
|
| 463 |
full = torch.randn(num_experts, 32, 64, device="cuda")
|
| 464 |
-
full_params_snapshot.append(full.clone())
|
| 465 |
p = torch.nn.Parameter(full)
|
| 466 |
-
|
| 467 |
-
full_grads.append(g.clone())
|
| 468 |
-
p.grad = g
|
| 469 |
muon_params.append(p)
|
| 470 |
muon_names.append(
|
| 471 |
f"stage{pp_rank}.layers.{num_dense}.experts.w1.weight")
|
|
@@ -492,42 +485,12 @@ def test_pp_dp_replicate_moe_no_deadlock(init_dist):
|
|
| 492 |
# Second step to verify cached path
|
| 493 |
for p in muon_params:
|
| 494 |
if isinstance(p.data, DTensor):
|
| 495 |
-
|
| 496 |
-
|
| 497 |
else:
|
| 498 |
p.grad = torch.randn_like(p.data)
|
| 499 |
optim.step()
|
| 500 |
|
| 501 |
-
# Correctness: compare against sequential baseline
|
| 502 |
-
seq_params = []
|
| 503 |
-
for fp in full_params_snapshot:
|
| 504 |
-
seq_params.append(torch.nn.Parameter(fp.clone()))
|
| 505 |
-
for p, g in zip(seq_params, full_grads):
|
| 506 |
-
p.grad = g.clone()
|
| 507 |
-
|
| 508 |
-
param_groups_seq = [{
|
| 509 |
-
"params": seq_params,
|
| 510 |
-
"names": list(muon_names),
|
| 511 |
-
"use_muon": True,
|
| 512 |
-
"lr": 0.02,
|
| 513 |
-
"weight_decay": 0.01,
|
| 514 |
-
"momentum": 0.95,
|
| 515 |
-
"nesterov": True,
|
| 516 |
-
"ns_steps": 5,
|
| 517 |
-
"none_grad": False,
|
| 518 |
-
}]
|
| 519 |
-
optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
|
| 520 |
-
optim_seq.step()
|
| 521 |
-
|
| 522 |
-
for i in range(len(muon_params)):
|
| 523 |
-
par_data = muon_params[i].data
|
| 524 |
-
if isinstance(par_data, DTensor):
|
| 525 |
-
par_data = par_data.full_tensor()
|
| 526 |
-
torch.testing.assert_close(par_data,
|
| 527 |
-
seq_params[i].data,
|
| 528 |
-
atol=0,
|
| 529 |
-
rtol=0)
|
| 530 |
-
|
| 531 |
set_ns_compile(True)
|
| 532 |
logger.info(
|
| 533 |
"test_pp_dp_replicate_moe_no_deadlock PASSED (rank %d, pp_rank %d)",
|
|
|
|
| 442 |
|
| 443 |
muon_params = []
|
| 444 |
muon_names = []
|
|
|
|
|
|
|
| 445 |
|
| 446 |
# Non-expert 2D DTensor params (both stages, different counts)
|
| 447 |
num_dense = 2 if pp_rank == 0 else 3
|
| 448 |
for i in range(num_dense):
|
| 449 |
full = torch.randn(32, 64, device="cuda")
|
|
|
|
| 450 |
dt = distribute_tensor(full, stage_mesh, placements)
|
| 451 |
p = torch.nn.Parameter(dt)
|
| 452 |
+
p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
|
| 453 |
+
stage_mesh, placements)
|
|
|
|
| 454 |
muon_params.append(p)
|
| 455 |
muon_names.append(f"stage{pp_rank}.layers.{i}.weight")
|
| 456 |
|
| 457 |
# Stage 1 only: 3D expert plain-tensor params
|
| 458 |
if pp_rank == 1:
|
| 459 |
full = torch.randn(num_experts, 32, 64, device="cuda")
|
|
|
|
| 460 |
p = torch.nn.Parameter(full)
|
| 461 |
+
p.grad = torch.randn(num_experts, 32, 64, device="cuda")
|
|
|
|
|
|
|
| 462 |
muon_params.append(p)
|
| 463 |
muon_names.append(
|
| 464 |
f"stage{pp_rank}.layers.{num_dense}.experts.w1.weight")
|
|
|
|
| 485 |
# Second step to verify cached path
|
| 486 |
for p in muon_params:
|
| 487 |
if isinstance(p.data, DTensor):
|
| 488 |
+
p.grad = distribute_tensor(torch.randn(32, 64, device="cuda"),
|
| 489 |
+
stage_mesh, placements)
|
| 490 |
else:
|
| 491 |
p.grad = torch.randn_like(p.data)
|
| 492 |
optim.step()
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
set_ns_compile(True)
|
| 495 |
logger.info(
|
| 496 |
"test_pp_dp_replicate_moe_no_deadlock PASSED (rank %d, pp_rank %d)",
|