Update comment to reflect use_local_synchronization behavior [skip-build]
torch-ext/optimizer/distributed/utils.py CHANGED
@@ -207,8 +207,10 @@ def construct_shard_mesh(
     assert len(shard_placements) == len(set(shard_placements))

     # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
-    #
-    #
+    # Each rank only creates the group it belongs to, using
+    # use_local_synchronization=True so that only group members need to
+    # coordinate. This avoids deadlocks when different PP stages call
+    # construct_shard_mesh for different parameters.
     def _cache_key(t: torch.Tensor) -> tuple:
         return (*t.shape, *t.flatten().tolist())
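For context, a minimal sketch of the pattern the new comment describes, assuming PyTorch >= 2.1 (where torch.distributed.new_group accepts use_local_synchronization). The create_my_subgroup helper and its contiguous-block grouping rule are hypothetical illustrations, not the repo's actual construct_shard_mesh logic:

import torch.distributed as dist

def create_my_subgroup(group_size: int) -> dist.ProcessGroup:
    """Each rank creates only the one sub-group it belongs to."""
    rank = dist.get_rank()
    # Hypothetical grouping rule for illustration: consecutive blocks of
    # `group_size` ranks form one sub-group.
    first = (rank // group_size) * group_size
    my_ranks = list(range(first, first + group_size))
    # With use_local_synchronization=True, only the member ranks take part
    # in creating this group; ranks outside `my_ranks` never call
    # new_group for it, so ranks building *different* groups (e.g.
    # different PP stages) cannot block on each other.
    return dist.new_group(ranks=my_ranks, use_local_synchronization=True)

Without use_local_synchronization, new_group is collective over the whole default group: every rank must call it for every group being created, in the same order, even as a non-member. That is exactly the coordination requirement that can deadlock when different pipeline stages reach construct_shard_mesh with different parameters, and what the updated comment says the code now avoids.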