Update fast path comment to reflect current behavior [skip-build]

torch-ext/optimizer/distributed/utils.py

@@ -163,9 +163,10 @@ def construct_shard_mesh(
     assert mesh.mesh.device.type == 'cpu'
 
     # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
-    #
-    #
-    #
+    # Reuses the mesh's existing ProcessGroup directly, avoiding the
+    # overhead of dist.new_group(). The standard path below also handles
+    # subset calls safely via use_local_synchronization=True, but this
+    # fast path is still beneficial for the common 1D shard case.
     if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
         key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
         if key not in _ranks_to_dist_cache:
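
For readers outside the codebase, the cached-ProcessGroup pattern the new comment describes looks roughly like the sketch below. This is a hypothetical reconstruction, not the file's actual code: the diff shows only the first lines of the fast path, so the wrapper name get_shard_group, the body of _is_shard, the cache's value type, and the standard-path branch are all assumptions. It assumes PyTorch >= 2.4 for the public torch.distributed.tensor import path.

    # Minimal sketch of the pattern; names not shown in the diff
    # (get_shard_group, the _is_shard body) are hypothetical.
    import torch.distributed as dist
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor import Shard

    # Mirrors the _ranks_to_dist_cache seen in the diff: maps a key built
    # from the mesh shape and ranks to a reusable ProcessGroup.
    _ranks_to_dist_cache: dict[tuple, dist.ProcessGroup] = {}

    def _is_shard(placement) -> bool:
        # Hypothetical stand-in for the _is_shard helper used in the diff.
        return isinstance(placement, Shard)

    def get_shard_group(mesh: DeviceMesh, placements) -> dist.ProcessGroup:
        key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())

        # Fast path: a 1D all-shard mesh already owns a ProcessGroup that
        # spans exactly its ranks, so reuse it instead of paying for
        # dist.new_group().
        if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
            if key not in _ranks_to_dist_cache:
                _ranks_to_dist_cache[key] = mesh.get_group()
            return _ranks_to_dist_cache[key]

        # Standard path: create (and cache) a new group for the mesh ranks.
        # use_local_synchronization=True means only the member ranks need
        # to call new_group(), so a subset of the world can do this without
        # a global rendezvous; the calling rank must be a member.
        if key not in _ranks_to_dist_cache:
            ranks = mesh.mesh.flatten().tolist()
            _ranks_to_dist_cache[key] = dist.new_group(
                ranks, use_local_synchronization=True
            )
        return _ranks_to_dist_cache[key]

The design point the comment is making: group creation allocates communicator state, so repeated calls with an equivalent mesh should return the same ProcessGroup from the cache rather than build a new one, and the fast path skips new_group() entirely when the mesh already carries a group spanning exactly the right ranks.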