Kernels
wyldecat and Claude Opus 4.6 committed
Commit da7e5da · 1 Parent(s): 96b287c

Fix deadlock in construct_shard_mesh with PP + dp_replicate > 1


When Pipeline Parallelism is combined with dp_replicate > 1, different PP
stages own different parameters and therefore call dist.new_group() in
different orders, causing a collective-mismatch deadlock. Fix by passing
use_local_synchronization=True, so only ranks within the same group need
to coordinate, and by skipping group creation for shard meshes the current
rank doesn't belong to.

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
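
For context: dist.new_group() is, by default, a collective over the entire
world, so every rank must call it for every group, and in the same order,
even for groups it never joins. A minimal sketch of how that rule bites here
(the two-stage layout and call ordering below are hypothetical, purely to
illustrate the mismatch):

# Hypothetical ownership: each PP stage owns different parameters, so
# each stage builds its shard_meshes list in a different order.
stage_call_order = {
    0: [[0, 2], [1, 3]],  # stage-0 ranks would call new_group({0,2}) first
    1: [[1, 3], [0, 2]],  # stage-1 ranks would call new_group({1,3}) first
}

# Pre-fix behavior: every rank calls dist.new_group() for EVERY shard
# mesh. The default new_group() is collective across the whole world and
# must be invoked in the same order on all ranks; here stage 0 is setting
# up group {0, 2} while stage 1 is setting up group {1, 3}, so the
# underlying collectives pair mismatched ranks and the job hangs.
for stage, order in stage_call_order.items():
    print(f"PP stage {stage} creates groups in order: {order}")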

torch-ext/optimizer/distributed/utils.py CHANGED
@@ -214,15 +214,16 @@ def construct_shard_mesh(
 
     my_key = None
     for sm in shard_meshes:
-        key = _cache_key(sm)
         if (my_rank == sm).any().item():
+            key = _cache_key(sm)
             assert my_key is None, "Rank appears in multiple shard groups"
             my_key = key
-        if key not in _ranks_to_dist_cache:
-            pg = dist.new_group(sm.flatten().tolist())
-            _ranks_to_dist_cache[key] = (
-                DeviceMesh(device_type="cuda", mesh=sm),
-                pg,
-            )
+            if key not in _ranks_to_dist_cache:
+                pg = dist.new_group(sm.flatten().tolist(),
+                                    use_local_synchronization=True)
+                _ranks_to_dist_cache[key] = (
+                    DeviceMesh(device_type="cuda", mesh=sm),
+                    pg,
+                )
 
     return (*_ranks_to_dist_cache[my_key], shard_placements)
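
The fix leans on use_local_synchronization (supported by
torch.distributed.new_group in recent PyTorch releases): group setup then
synchronizes only among the group's own members, so ranks outside the group
neither call new_group() nor wait on it, and call order across pipeline
stages stops mattering. A distilled sketch of the resulting pattern,
assuming an initialized default process group (make_group_if_member is an
illustrative helper, not part of this repo):

import torch.distributed as dist

def make_group_if_member(ranks: list[int]):
    """Create a process group only on the ranks that belong to it.

    With use_local_synchronization=True, dist.new_group() synchronizes
    only among `ranks`; non-member ranks can skip the call entirely
    instead of mirroring it in a globally agreed order.
    """
    if dist.get_rank() not in ranks:
        return None  # non-members skip: no world-wide handshake needed
    return dist.new_group(ranks, use_local_synchronization=True)

This is the same shape as the fixed loop above: the membership check
((my_rank == sm).any()) now guards group creation, and the cache still ends
up holding the one (DeviceMesh, ProcessGroup) pair the current rank needs
for the final lookup.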