zjowowen's picture
init space
079c32c
raw
history blame
10.3 kB
from time import sleep, time
from ditk import logging
from ding.framework import task
from ding.utils.lock_helper import LockContext, LockContextType
from ding.utils.design_helper import SingletonMetaclass
class BarrierRuntime(metaclass=SingletonMetaclass):
    """
    Overview:
        Per-process state shared by ``Barrier`` middlewares. It tracks peer liveness detection,
        barrier requests received from peers (per peer, per epoch) and barrier acks received in
        the current round.
    """

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
            the detection is completed. We don't have a message retransmission mechanism, and losing
            a message means deadlock.
        Arguments:
            - node_id (int): Process ID.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        # Barrier request tags pack (epoch, node_id) into one int as ``epoch * _range_len + node_id``.
        # Decoding is only unambiguous when ``_range_len`` is strictly greater than every node id,
        # so use the next power of ten above ``max_world_size`` (1000 for the default of 100).
        # A multiplier derived from the digit count alone would be far too small and would make
        # distinct (peer, epoch) pairs collide.
        self._range_len = 10 ** len(str(max_world_size))
        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        # peer id -> list of epochs for which that peer has sent a barrier request (FIFO).
        self._barrier_recv_peers = dict()
        # Peer ids that have sent a barrier ack in the current round.
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)
        self.mq_type = task.router.mq_type
        # peer id -> time() at which a detect message from that peer was last received.
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False
        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        """Record the time at which a detect message from ``peer_id`` was received."""
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        """Queue the epoch carried by a peer's barrier request tag ``msg``."""
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            self._barrier_recv_peers.setdefault(peer, []).append(epoch)

    def _add_barrier_ack(self, peer):
        """Record that ``peer`` has acknowledged the current barrier round."""
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        """Decode a tag produced by ``pickle_barrier_tag`` into ``(peer_id, epoch)``."""
        epoch, peer = divmod(msg, self._range_len)
        return peer, epoch

    def pickle_barrier_tag(self):
        """
        Overview:
            Encode ``(current epoch, node_id)`` into a single int sent with barrier requests.
        Raises:
            - ValueError: If ``node_id`` is outside ``[0, _range_len)`` and therefore cannot be
                decoded unambiguously by peers.
        """
        if not 0 <= self.node_id < self._range_len:
            raise ValueError(
                "Node:[{}] node_id must be in [0, {}) to be encoded in a barrier tag, "
                "please increase 'max_world_size'.".format(self.node_id, self._range_len)
            )
        return int(self._barrier_epoch * self._range_len + self.node_id)

    def reset_all_peers(self):
        """Finish the current barrier round: consume each peer's request, clear acks, bump epoch."""
        with self._barrier_lock:
            for q in self._barrier_recv_peers.values():
                if q:
                    # Pop outside the assert so the queue is still drained when asserts are
                    # disabled (python -O); otherwise the next round would deadlock.
                    head = q.pop(0)
                    assert head == self._barrier_epoch, (head, self._barrier_epoch)
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        """Return how many peers have sent a barrier request for the current epoch."""
        with self._barrier_lock:
            return sum(1 for q in self._barrier_recv_peers.values() if q and q[0] == self._barrier_epoch)

    def get_ack_num(self):
        """Return how many barrier acks have been received in the current round."""
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        """
        Overview:
            Register the barrier event handlers and broadcast detect messages until exactly
            ``expected`` peers have answered, then send one more round and stop detecting.
        Arguments:
            - expected (int): Number of peers that must be detected.
            - timeout (float): Seconds to wait before giving up.
        Raises:
            - TimeoutError: If ``expected`` peers are not detected within ``timeout`` seconds.
        """
        # The barrier can only block other nodes within the visible range of the current node.
        # If the 'attach_to' list of a node is empty, it does not know how many nodes will attach to it,
        # so we cannot specify the effective range of a barrier in advance.
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # In case the other node has not had time to receive our detect message,
            # we will send an additional round after detection succeeds.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                detected_now = len(self._connected_peers) == expected
            if detected_now:
                self._has_detected = True
                # Do not apply the timeout on the iteration where detection just succeeded.
                continue
            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )
class BarrierContext:
    """
    Overview:
        Context manager around a single barrier round. On entry it runs peer detection once per
        runtime (skipped if the runtime has already detected its peers); on exit it always
        finishes the round via ``runtime.reset_all_peers()``, printing the traceback first if the
        body raised. Exceptions are not suppressed.
    Arguments:
        - runtime (BarrierRuntime): The shared barrier runtime.
        - detect_timeout: Timeout in seconds passed to ``runtime.detect_alive``.
        - expected_peer_num (int): Number of peers detection must see.
    """

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        self._runtime, self._expected_peer_num, self._timeout = runtime, expected_peer_num, detect_timeout

    def __enter__(self):
        rt = self._runtime
        if rt._has_detected:
            # Detection already done for this runtime; nothing to do.
            return
        rt.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        body_failed = exc_type is not None
        if body_failed:
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        # Returning None lets any exception propagate after the round is closed.
        self._runtime.reset_all_peers()
class Barrier:

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Middleware for debugging/profiling that synchronizes the task step of every process
            visible to this one. Keep in mind:
            1. Every process must invoke Barrier() the same number of times, or a deadlock occurs.
            2. 'attch_from_nums' must be correct: it is the number of other processes that will
               attach to this one (i.e. connections established towards this node).
               Example:
                   Node0: address: 127.0.0.1:12345, attach_to = []
                   Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
               Node0 uses attch_from_nums=1 (Node1 attaches to it); Node1 uses 0 (nobody attaches
               to it). A node with an empty 'attach_to' list cannot otherwise know how many peers
               will connect to it, so no synchronization is possible if this is wrong.
            3. Barrier() is thread-safe, but using it from several threads is discouraged: every
               thread's call count must line up exactly or a deadlock occurs.
            4. Do not use Barrier() in normal training: forcing step synchronization across
               processes badly hurts throughput. Never use it with dynamically joining/leaving
               processes, as that can deadlock.
        Arguments:
            - attch_from_nums (int): Number of processes that attach to this process.
            - timeout (int, optional): Seconds allowed for peer detection and for each waiting
                phase of a barrier; defaults to 60.
        """
        self.node_id = task.router.node_id
        self.timeout = timeout
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        attach_to_num = task.get_attch_to_len()
        # Peers this node waits on: the ones it attaches to plus the ones attaching to it.
        self._barrier_peers_nums = attach_to_num + attch_from_nums
        logging.info(
            "Node:[{}], attach to num is:{}, attach from num is:{}".format(self.node_id, attach_to_num, attch_from_nums)
        )

    def __call__(self, ctx):
        # Synchronize before handing control to the rest of the pipeline, and again afterwards.
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _poll_until(self, is_done):
        # Check ``is_done()`` every 0.1s; raise TimeoutError once ``self.timeout`` seconds pass.
        began = time()
        while True:
            if time() - began > self.timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
            if is_done():
                return
            sleep(0.1)

    def _wait_barrier(self, ctx):
        runtime = self._runtime
        with BarrierContext(runtime, self.timeout, self._barrier_peers_nums):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))
            # Phase 1: tell every attached peer we reached the barrier (tag carries our epoch).
            task.emit(runtime.event_name_req, runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sended barrier request".format(self.node_id))
            # Phase 2: hold our ack until every expected peer has sent its request for this epoch,
            # so the slowest directly or indirectly connected node gates the whole group. In the
            # DI-engine CI layout some node always has an empty 'attach_to' list and acks
            # unconditionally, which is what prevents deadlock here.
            self._poll_until(lambda: runtime.get_recv_num() == self._barrier_peers_nums)
            # Phase 3: tell attached peers we are ready.
            task.emit(runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sended barrier ack".format(self.node_id))
            # Phase 4: wait for every expected peer to report ready.
            self._poll_until(lambda: runtime.get_ack_num() == self._barrier_peers_nums)
            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))