|
from contextlib import contextmanager |
|
from typing import cast |
|
import logging |
|
from . import api |
|
from . import TensorPipeAgent |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
@contextmanager |
|
def _group_membership_management(store, name, is_join): |
|
token_key = "RpcGroupManagementToken" |
|
join_or_leave = "join" if is_join else "leave" |
|
my_token = f"Token_for_{name}_{join_or_leave}" |
|
while True: |
|
|
|
returned = store.compare_set(token_key, "", my_token).decode() |
|
if returned == my_token: |
|
|
|
yield |
|
|
|
|
|
store.set(token_key, "") |
|
|
|
store.set(my_token, "Done") |
|
break |
|
else: |
|
|
|
try: |
|
store.wait([returned]) |
|
except RuntimeError: |
|
logger.error(f"Group membership token {my_token} timed out waiting for {returned} to be released.") |
|
raise |
|
|
|
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    Looks up the agent through ``api._get_current_rpc_agent()``, treats it as a
    ``TensorPipeAgent`` for type-checking purposes (``cast`` performs no runtime
    check), and delegates to the agent's own ``_update_group_membership``.

    Args:
        worker_info: Passed through unchanged to the agent.
        my_devices: Passed through unchanged to the agent.
        reverse_device_map: Passed through unchanged to the agent.
        is_join: Passed through unchanged to the agent.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    return current_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )
|
|