"""
|
|
|
the class of WorkerGroup
|
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
import signal
|
|
|
import threading
|
|
|
import time
|
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
|
|
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
|
|
|
|
|
|
|
|
|
class ResourcePool:
    """The resource pool with meta info such as world_size."""

    def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None:
        if process_on_nodes is None:
            process_on_nodes = []
        # number of processes to launch on each node
        self._store = process_on_nodes
        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = n_gpus_per_node

    def add_node(self, process_count):
        self._store.append(process_count)

    @property
    def world_size(self):
        """Total number of processes across all nodes."""
        return sum(self._store)

    def __call__(self) -> Any:
        return self._store

    @property
    def store(self):
        return self._store

    def local_world_size_list(self) -> List[int]:
        """Return a flat list that maps every process to the local world size of its node."""
        nested_local_world_size_list = [[local_world_size for _ in range(local_world_size)] for local_world_size in self._store]
        return [item for row in nested_local_world_size_list for item in row]

    def local_rank_list(self) -> List[int]:
        """Return a flat list of the local rank of every process across all nodes."""
        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
        return [item for row in nested_local_rank_list for item in row]
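# Illustrative usage of ResourcePool (sketch, not executed at import time):
# a pool spanning two nodes that run 8 and 4 processes respectively.
#
#   pool = ResourcePool(process_on_nodes=[8, 4])
#   pool.world_size               # 12
#   pool.local_world_size_list()  # [8, 8, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4]
#   pool.local_rank_list()        # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]
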
class ClassWithInitArgs:
    """
    Stores a class constructor together with the args/kwargs needed to construct it,
    so that instantiation of the remote class can be deferred.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

        self.fused_worker_used = False

    def __call__(self) -> Any:
        """Instantiate the stored class with the stored args/kwargs."""
        return self.cls(*self.args, **self.kwargs)
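# Illustrative usage of ClassWithInitArgs (sketch; `MyWorker` and `cfg` are assumed
# user-defined names, not part of this module):
#
#   cls_with_args = ClassWithInitArgs(MyWorker, config=cfg)
#   ...                       # set up resources / placement first
#   worker = cls_with_args()  # equivalent to MyWorker(config=cfg)
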
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    """Periodically poll the given workers; if any of them has died, signal the main thread.

    This is meant to run in a background thread. When a dead worker is detected, SIGABRT is
    raised in the current process so that the main thread can react to the failure.
    """
    while True:
        for worker in workers:
            if not is_alive(worker):
                logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
                signal.raise_signal(signal.SIGABRT)
        time.sleep(gap_time)
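# Illustrative sketch (assumed driver-side usage, not required by this module): installing a
# SIGABRT handler in the main thread turns a worker death detected by check_workers_alive
# into an exception instead of the default process abort.
#
#   def _on_worker_death(signum, frame):
#       raise RuntimeError("a worker process died; shutting down the driver")
#
#   signal.signal(signal.SIGABRT, _on_worker_death)
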
class WorkerGroup:
    """A group of workers."""

    fused_worker_execute_fn_name = "_fuw_execute"

    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        # A WorkerGroup may also be attached to already-running (detached) workers,
        # in which case no resource pool is passed in.
        self._is_init_with_detached_workers = resource_pool is None

        self.fused_worker_used = False

        if resource_pool is not None:
            # resource_pool() returns the per-node process counts (see ResourcePool.__call__)
            self._process_dispatch_config = resource_pool()
        else:
            self._process_dispatch_config = None

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        self._checker_thread: Optional[threading.Thread] = None

    def _is_worker_alive(self, worker):
        """Return whether a single worker is alive. Must be implemented by derived classes."""
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self) -> None:
        """Block until every worker in the group reports alive."""
        while True:
            all_state = [self._is_worker_alive(worker) for worker in self._workers]
            if False in all_state:
                time.sleep(1)
            else:
                break

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        """Wait until all workers are alive, then start a background thread that keeps polling them."""
        self._block_until_all_workers_alive()

        self._checker_thread = threading.Thread(target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds))
        self._checker_thread.start()

    @property
    def world_size(self):
        """Number of workers in the group."""
        return len(self._workers)
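    # Illustrative sketch (assumed subclass, not part of this module): a concrete WorkerGroup
    # only needs to define how a single worker is probed for the aliveness machinery above
    # to work, e.g. for subprocess-like worker handles:
    #
    #   class LocalWorkerGroup(WorkerGroup):
    #       def _is_worker_alive(self, worker):
    #           return worker.poll() is None
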
    def _bind_worker_method(self, user_defined_cls, func_generator):
        """
        Bind the methods of the user-defined worker class that carry MAGIC_ATTR metadata
        to this WorkerGroup, and return the list of method names that were bound.
        """
        method_names = []
        for method_name in dir(user_defined_cls):
            try:
                method = getattr(user_defined_cls, method_name)
                assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
            except Exception:
                # skip properties and other attributes that cannot be resolved to a callable on the class
                continue

            if hasattr(method, MAGIC_ATTR):
                # this method carries dispatch/execute metadata attached by the registration decorator
                attribute = getattr(method, MAGIC_ATTR)
                assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
                assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"

                dispatch_mode = attribute["dispatch_mode"]
                execute_mode = attribute["execute_mode"]
                blocking = attribute["blocking"]

                # resolve the dispatch/collect functions, either predefined or user-provided
                if isinstance(dispatch_mode, Dispatch):
                    fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                    dispatch_fn = fn["dispatch_fn"]
                    collect_fn = fn["collect_fn"]
                else:
                    assert isinstance(dispatch_mode, dict)
                    assert "dispatch_fn" in dispatch_mode
                    assert "collect_fn" in dispatch_mode
                    dispatch_fn = dispatch_mode["dispatch_fn"]
                    collect_fn = dispatch_mode["collect_fn"]

                # resolve the execute function by name on this WorkerGroup
                execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
                wg_execute_fn_name = execute_mode["execute_fn_name"]

                try:
                    execute_fn = getattr(self, wg_execute_fn_name)
                    assert callable(execute_fn), "execute_fn must be callable"
                except Exception:
                    print(f"execute_fn {wg_execute_fn_name} is invalid")
                    raise

                # generate a proxy that dispatches inputs, executes on the workers, and collects outputs
                func = func_generator(
                    self,
                    method_name,
                    dispatch_fn=dispatch_fn,
                    collect_fn=collect_fn,
                    execute_fn=execute_fn,
                    blocking=blocking,
                )

                try:
                    setattr(self, method_name, func)
                    method_names.append(method_name)
                except Exception as e:
                    raise ValueError(f"Fail to set method_name {method_name}") from e

        return method_names
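# Illustrative usage of _bind_worker_method (sketch; `MyWorker`, `my_func_generator` and the
# `generate` method are assumed names): a concrete WorkerGroup subclass binds the decorated
# methods of the worker class onto itself, so callers invoke them as if they were local.
#
#   bound = worker_group._bind_worker_method(MyWorker, func_generator=my_func_generator)
#   # bound == ["generate", ...]; each bound method dispatches its inputs, executes on all
#   # workers, and collects the outputs.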