|
from time import sleep, time |
|
from ditk import logging |
|
from ding.framework import task |
|
from ding.utils.lock_helper import LockContext, LockContextType |
|
from ding.utils.design_helper import SingletonMetaclass |
|
|
|
|
|
class BarrierRuntime(metaclass=SingletonMetaclass):
    """
    Overview:
        Per-process state shared by all barriers of this process: the peers recorded by the
        liveness probe, the barrier requests received from each peer (one FIFO list of epochs per
        peer), the acks received for the current barrier, and the current barrier epoch.
    """

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
            the detection is completed. We don't have a message retransmission mechanism, and losing
            a message means deadlock.
        Arguments:
            - node_id (int): Process ID. It must be non-negative and not larger than
                ``max_world_size`` so that it can be packed into a barrier tag.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        # Modulus used to pack (epoch, node_id) into a single integer tag:
        #     tag = epoch * modulus + node_id
        # The tag can only be split back with divmod when the modulus is strictly greater than
        # every node id. The previous value, len(str(max_world_size)) + 1 (4 for the default),
        # was a digit count rather than a modulus, so any node_id >= 4 was decoded as the wrong
        # peer and the wrong epoch. A power of ten with one more digit than max_world_size is
        # always larger than any valid node id.
        self._range_len = 10 ** (len(str(max_world_size)) + 1)

        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        # peer id -> list of epochs for which a barrier request has been received, oldest first.
        self._barrier_recv_peers = dict()
        # Peer ids whose ack has been received since the last reset.
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)

        self.mq_type = task.router.mq_type
        # peer id -> time() of the latest detect message received from that peer.
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False

        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        """
        Overview:
            Handler of the detect event: record the time at which ``peer_id`` was last seen.
        """
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        """
        Overview:
            Handler of the barrier request event: decode the tag and append its epoch to the
            queue of the sending peer.
        Arguments:
            - msg (int): Tag produced by ``pickle_barrier_tag`` on the sending node.
        """
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            self._barrier_recv_peers.setdefault(peer, []).append(epoch)

    def _add_barrier_ack(self, peer):
        """
        Overview:
            Handler of the barrier ack event: remember that ``peer`` has sent its ack.
        """
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        """
        Overview:
            Split a barrier tag back into its two components.
        Returns:
            - (peer, epoch): The node id of the sender and the barrier epoch it was in.
        """
        epoch, peer = divmod(msg, self._range_len)
        return peer, epoch

    def pickle_barrier_tag(self):
        """
        Overview:
            Pack the current epoch and this node's id into one integer
            (``epoch * modulus + node_id``), the inverse of ``_unpickle_barrier_tag``.
        """
        return int(self._barrier_epoch * self._range_len + self.node_id)

    def reset_all_peers(self):
        """
        Overview:
            Finish the current barrier: drop the oldest pending request of every peer, clear the
            received acks and advance the epoch by one.
        """
        with self._barrier_lock:
            for pending in self._barrier_recv_peers.values():
                if pending:
                    # The pop must not live inside the assert expression: assert statements are
                    # removed under ``python -O``, which would leave the queues undrained and
                    # make every later barrier wait for a request that never reaches the head.
                    finished_epoch = pending.pop(0)
                    assert finished_epoch == self._barrier_epoch, \
                        "barrier epoch mismatch: got {}, expected {}".format(finished_epoch, self._barrier_epoch)
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        """
        Overview:
            Count the peers whose oldest pending barrier request belongs to the current epoch.
        """
        with self._barrier_lock:
            return sum(1 for pending in self._barrier_recv_peers.values() if pending and pending[0] == self._barrier_epoch)

    def get_ack_num(self):
        """
        Overview:
            Return the number of acks received since the last reset.
        """
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        """
        Overview:
            Register the detect, request and ack handlers on ``task``, then emit a detect event
            every 0.1 seconds until exactly ``expected`` distinct peers have been recorded.
        Arguments:
            - expected (int): Number of distinct peers that must be seen.
            - timeout (float): Seconds to wait before giving up.
        Raises:
            - TimeoutError: If more than ``timeout`` seconds elapse inside the loop.
        """
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # The flag is checked only after emitting, so once the expected count is reached the
            # loop still sends one more detect event before it stops.
            # NOTE(review): presumably this gives slower peers a last chance to see us - confirm.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                if len(self._connected_peers) == expected:
                    self._has_detected = True

            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )
|
|
|
|
|
class BarrierContext:
    """
    Overview:
        Context manager placed around one barrier pass. On entry it runs the runtime's peer
        detection unless the runtime reports that detection has already been done; on exit it
        always resets the runtime's per-barrier state, whether or not the body raised.
    """

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        """
        Arguments:
            - runtime (BarrierRuntime): Runtime whose detection and reset methods are driven.
            - detect_timeout: Timeout forwarded to ``runtime.detect_alive``.
            - expected_peer_num (int, optional): Peer count forwarded to ``runtime.detect_alive``,
                the default value is 0.
        """
        self._runtime, self._timeout, self._expected_peer_num = runtime, detect_timeout, expected_peer_num

    def __enter__(self):
        runtime = self._runtime
        # Detection is only needed until the runtime has flagged it as done.
        if runtime._has_detected:
            return
        runtime.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        raised = exc_type is not None
        if raised:
            # Print the traceback here; the exception is not suppressed (nothing truthy is
            # returned), so it still propagates to the caller afterwards.
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        self._runtime.reset_all_peers()
|
|
|
|
|
class Barrier:
    """
    Overview:
        Middleware that synchronizes the task step of this process with its connected peers, once
        before and once after the rest of the step.
    """

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Barrier() is a middleware for debug or profiling. It can synchronize the task step of each
            process within the scope of all visible processes. When using Barrier(), you need to pay
            attention to the following points:

            1. All processes must call the same number of Barrier(), otherwise a deadlock occurs.

            2. 'attch_from_nums' is a very important variable. This value indicates the number of times
                the current process will be attached to by other processes (the number of connections
                established).
                For example:
                    Node0: address: 127.0.0.1:12345, attach_to = []
                    Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
                    For Node0, the 'attch_from_nums' value is 1. (It will be attached by Node1)
                    For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1)
                Please note that this value must be given correctly, otherwise, for a node whose 'attach_to'
                list is empty, it cannot perceive how many processes will establish connections with it,
                resulting in any form of synchronization cannot be performed.

            3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
                to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.

            4. In normal training tasks, please do not use Barrier(), which will force the step synchronization
                between each process, so it will greatly damage the training efficiency. In addition, if your
                training task has dynamic processes, do not use Barrier() to prevent deadlock.

        Arguments:
            - attch_from_nums (int): Number of other processes that will attach to (establish a
                connection with) the current process, see point 2 above.
            - timeout (int, optional): The timeout for successful detection of 'expected_peer_num'
                number of nodes, the default value is 60 seconds. The same value also bounds each
                waiting phase of the barrier.
        """
        self.timeout = timeout
        self.node_id = task.router.node_id
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        # Peers to synchronize with: the ones this node attaches to plus the ones attaching to it.
        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums

        summary = "Node:[{}], attach to num is:{}, attach from num is:{}".format(
            self.node_id, task.get_attch_to_len(), attch_from_nums
        )
        logging.info(summary)

    def __call__(self, ctx):
        """
        Overview:
            Middleware entry point: wait at the barrier, yield to the rest of the step, then wait
            at the barrier again.
        """
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _wait_barrier(self, ctx):
        """
        Overview:
            Run one barrier pass: send a request, wait until a request of the current epoch has
            arrived from every peer, send an ack, then wait until every peer's ack has arrived.
        Raises:
            - TimeoutError: If either waiting phase lasts longer than ``self.timeout`` seconds.
        """
        runtime = self._runtime
        peers = self._barrier_peers_nums
        with BarrierContext(runtime, self.timeout, peers):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))

            # Phase 1: announce to the remote side that this node has reached the barrier.
            task.emit(runtime.event_name_req, runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sended barrier request".format(self.node_id))

            # Phase 2: wait until every peer's request for this epoch is in. The first check is
            # made before the clock starts, so an already complete set never times out.
            requests_complete = runtime.get_recv_num() == peers
            began = time()
            while not requests_complete:
                if time() - began > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                requests_complete = runtime.get_recv_num() == peers
                if not requests_complete:
                    sleep(0.1)

            # Phase 3: tell the remote side that this node has seen all requests.
            task.emit(runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sended barrier ack".format(self.node_id))

            # Phase 4: wait until every peer's ack is in; here the timeout is checked first.
            began = time()
            while True:
                if time() - began > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                if runtime.get_ack_num() == peers:
                    break
                sleep(0.1)

            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))
|
|