Fix deadlock in construct_shard_mesh with PP + dp_replicate > 1
When Pipeline Parallelism is combined with dp_replicate > 1, different PP
stages own different parameters and call dist.new_group() in different
orders, causing a collective mismatch deadlock. Fix by using
use_local_synchronization=True so only ranks within the same group need
to coordinate, and skip creating groups for shard meshes the current rank
doesn't belong to.
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
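For context, a minimal sketch of the failure mode and of the pattern the fix adopts. This is not code from this repository: the four-rank layout and the two helper functions are made up for illustration, and only dist.new_group() and its use_local_synchronization flag are the real torch.distributed API.

import torch.distributed as dist

# Illustrative layout: four ranks split into two shard groups. Assumes
# dist.init_process_group() has already been called on every rank.
SHARD_GROUPS = [[0, 1], [2, 3]]

def create_groups_deadlock_prone(shard_groups):
    # Without use_local_synchronization, dist.new_group() must be entered by
    # every rank of the default group, for every group, in the same order,
    # even by ranks that are not members of that group. If one PP stage
    # enumerates its shard meshes as [[0, 1], [2, 3]] while another
    # enumerates them as [[2, 3], [0, 1]], the group-creation collectives
    # pair up across different groups and all ranks hang.
    return [dist.new_group(ranks) for ranks in shard_groups]

def create_my_group_only(shard_groups, my_rank):
    # The pattern used by the fix: each rank creates only the group that
    # contains it. With use_local_synchronization=True, only the member
    # ranks of that group take part in its creation, so the order in which
    # disjoint groups are created on different ranks no longer matters.
    for ranks in shard_groups:
        if my_rank in ranks:
            return dist.new_group(ranks, use_local_synchronization=True)
    return None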
torch-ext/optimizer/distributed/utils.py (changed)

@@ -214,15 +214,16 @@ def construct_shard_mesh(

     my_key = None
     for sm in shard_meshes:
-        key = _cache_key(sm)
         if (my_rank == sm).any().item():
+            key = _cache_key(sm)
             assert my_key is None, "Rank appears in multiple shard groups"
             my_key = key
-        ...  (6 removed lines; their content is not shown in the rendered diff)
+            if key not in _ranks_to_dist_cache:
+                pg = dist.new_group(sm.flatten().tolist(),
+                                    use_local_synchronization=True)
+                _ranks_to_dist_cache[key] = (
+                    DeviceMesh(device_type="cuda", mesh=sm),
+                    pg,
+                )

     return (*_ranks_to_dist_cache[my_key], shard_placements)
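A side effect of this pattern, noted as an observation rather than something the diff states: since each rank now calls dist.new_group() only for the shard mesh it belongs to, _ranks_to_dist_cache on a given process only ever holds entries for that rank's own groups, so _cache_key(sm) only has to be stable across repeated calls on the same rank rather than agree across ranks. The return value is unchanged: the rank's DeviceMesh, its ProcessGroup, and shard_placements.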