Kernels
wyldecat, Claude Opus 4.6 committed
Commit 7e33533 · 1 Parent(s): 3f5cf49

Update fast path comment to reflect current behavior [skip-build]

torch-ext/optimizer/distributed/utils.py CHANGED
```diff
@@ -163,9 +163,10 @@ def construct_shard_mesh(
     assert mesh.mesh.device.type == 'cpu'
 
     # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
-    # This avoids a non-collective dist.new_group() call, which would
-    # deadlock when only a subset of ranks call this function (e.g. expert
-    # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
+    # Reuses the mesh's existing ProcessGroup directly, avoiding the
+    # overhead of dist.new_group(). The standard path below also handles
+    # subset calls safely via use_local_synchronization=True, but this
+    # fast path is still beneficial for the common 1D shard case.
     if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
         key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
         if key not in _ranks_to_dist_cache:
```
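For context, here is a minimal sketch of the two paths the updated comment describes. `get_or_create_group` is a hypothetical wrapper (the diff only shows a fragment of `construct_shard_mesh`), and the `_is_shard` body, the import paths, and the cache type are assumptions; only the fast-path condition and the cache key are taken verbatim from the diff.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Placement, Shard

# Hypothetical module-level cache, mirroring the `_ranks_to_dist_cache`
# referenced in the diff: keyed on mesh shape plus flattened rank list.
_ranks_to_dist_cache: dict[tuple, dist.ProcessGroup] = {}


def _is_shard(placement: Placement) -> bool:
    # Assumption: the real helper simply checks for a Shard placement.
    return isinstance(placement, Shard)


def get_or_create_group(
    mesh: DeviceMesh, placements: tuple[Placement, ...]
) -> dist.ProcessGroup:
    # Fast path: a 1D all-shard mesh already carries a ProcessGroup, so we
    # can hand it back (and cache it) instead of building a new one.
    if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
        key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
        if key not in _ranks_to_dist_cache:
            _ranks_to_dist_cache[key] = mesh.get_group()
        return _ranks_to_dist_cache[key]

    # Standard path: create a fresh group. With
    # use_local_synchronization=True, only the ranks that belong to the new
    # group must call new_group(), so a subset of ranks reaching this code
    # (e.g. expert DTensors on a TP submesh) no longer deadlocks.
    ranks = mesh.mesh.flatten().tolist()
    return dist.new_group(ranks=ranks, use_local_synchronization=True)
```

The key point behind the comment change: `dist.new_group(..., use_local_synchronization=True)` (available since PyTorch 2.1) requires only the new group's own members to participate in the call, which is why the fast path is now framed as an overhead optimization rather than a deadlock workaround.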