Kernels
wyldecat Claude Opus 4.6 committed on
Commit 3f5cf49 · 1 Parent(s): da7e5da

Update comment to reflect use_local_synchronization behavior [skip-build]

torch-ext/optimizer/distributed/utils.py CHANGED
@@ -207,8 +207,10 @@ def construct_shard_mesh(
     assert len(shard_placements) == len(set(shard_placements))
 
     # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
-    # All ranks must call dist.new_group in the same order, even though each
-    # rank only joins one group.
+    # Each rank only creates the group it belongs to, using
+    # use_local_synchronization=True so that only group members need to
+    # coordinate. This avoids deadlocks when different PP stages call
+    # construct_shard_mesh for different parameters.
     def _cache_key(t: torch.Tensor) -> tuple:
         return (*t.shape, *t.flatten().tolist())
 
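
For reference, a minimal sketch of the pattern the updated comment describes. The helper name is hypothetical (the real logic lives in construct_shard_mesh, not shown here); dist.new_group and its use_local_synchronization flag are actual PyTorch API, available since PyTorch 2.1:

import torch.distributed as dist

def _get_or_create_shard_group(shard_ranks: list[int]) -> dist.ProcessGroup:
    # With use_local_synchronization=True, only the ranks in `shard_ranks`
    # participate in group creation; ranks outside the group never have to
    # call new_group at all. Different PP stages can therefore build groups
    # for different parameters without agreeing on a global call order.
    assert dist.get_rank() in shard_ranks  # each rank builds only its own group
    return dist.new_group(ranks=shard_ranks, use_local_synchronization=True)

Without the flag, new_group is collective over the default process group, so every rank must issue the same sequence of new_group calls even for groups it never joins; that ordering constraint is what the old comment documented and what the new comment explains is no longer required.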