__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] import collections import enum from typing import cast, Dict, List, Set, Tuple import torch import torch.distributed as dist from ._utils import _group_membership_management, _update_group_membership from . import api from . import constants as rpc_constants __all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend", "BackendValue", "BackendType"] BackendValue = collections.namedtuple( "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] ) def _backend_type_repr(self): return "BackendType." + self.name _backend_type_doc = """ An enum class of available backends. PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. Additional ones can be registered using the :func:`~torch.distributed.rpc.backend_registry.register_backend` function. """ # Create an enum type, `BackendType`, with empty members. # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=dict()) # type: ignore[misc] # Unable to assign a function a method (mypy bug #2427) BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. Args: backend_name (str): string to identify the RPC backend. Returns: True if the backend has been registered with ``register_backend``, else False. """ return backend_name in BackendType.__members__.keys() def register_backend( backend_name, construct_rpc_backend_options_handler, init_backend_handler ): """Registers a new RPC backend. Args: backend_name (str): backend string to identify the handler. construct_rpc_backend_options_handler (function): Handler that is invoked when rpc_backend.construct_rpc_backend_options(**dict) is called. init_backend_handler (function): Handler that is invoked when the `_init_rpc_backend()` function is called with a backend. This returns the agent. """ global BackendType if backend_registered(backend_name): raise RuntimeError(f"RPC backend {backend_name}: already registered") # Create a new enum type, `BackendType`, with extended members. existing_enum_dict = {member.name: member.value for member in BackendType} extended_enum_dict = dict( { backend_name: BackendValue( construct_rpc_backend_options_handler=construct_rpc_backend_options_handler, init_backend_handler=init_backend_handler, ) }, **existing_enum_dict ) # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] # Unable to assign a function a method (mypy bug #2427) BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] def construct_rpc_backend_options( backend, rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, init_method=rpc_constants.DEFAULT_INIT_METHOD, **kwargs ): return backend.value.construct_rpc_backend_options_handler( rpc_timeout, init_method, **kwargs ) def init_backend(backend, *args, **kwargs): return backend.value.init_backend_handler(*args, **kwargs) def _init_process_group(store, rank, world_size): # Initialize ProcessGroup. 


def _init_process_group(store, rank, world_size):
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(
            f"rank argument {rank} doesn't match pg rank {group.rank()}"
        )
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            f"world_size argument {world_size} doesn't match pg size {group.size()}"
        )
    return group


def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs
):
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )


def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices
    )


# Detect if any worker has invalid device_map configurations, and return
# reverse device maps.
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # Passed all checks; construct the reverse mapping and get the list of
    # devices handled by this agent.
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices
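
# Usage sketch (illustrative, kept in comments): the handler above is what ultimately
# backs a user-facing setup such as the following, where a device on "worker0" is
# mapped onto a device of a peer named "worker1". Worker names and ranks here are
# examples only.
#
#     import torch.distributed.rpc as rpc
#
#     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=8, rpc_timeout=60)
#     options.set_device_map("worker1", {0: 1})  # local cuda:0 -> remote cuda:1
#     rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)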


def _validate_device_maps(
    all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For a dynamic (non-static) group, do not check the target node name
        # since it may not have joined yet.
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
            all_names
        ):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )


def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for map_ in my_device_maps.values():
            devices_set.update(map_.keys())
        for map_ in reverse_device_maps.values():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices


def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }
    return reverse_device_maps


def _get_device_infos():
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = torch.cuda.device_count()
    if torch.cuda.is_available() and opts.devices:
        torch.cuda.init()
    return device_count, opts.device_maps, opts.devices
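
# Worked example (illustrative, kept in comments) of the reverse mapping above:
# if worker "A" declares a device map to "B" of {cuda:0: cuda:1}, then "B"
# computes the inverse map for "A":
#
#     all_device_maps = {
#         "A": {"B": {torch.device("cuda:0"): torch.device("cuda:1")}},
#         "B": {},
#     }
#     _create_reverse_mapping("B", ["A", "B"], all_device_maps)
#     # -> {"A": {torch.device("cuda:1"): torch.device("cuda:0")}}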


def _set_devices_and_reverse_device_map(agent):
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from the local agent. On initialization, the
    # TensorPipe agent retrieves information from all existing workers, so the
    # group state is valid.
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()
    # One round to get device_maps of all workers and construct reverse device maps.
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(
                worker_name, _get_device_infos
            )
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = (
                torch.cuda.device_count(),
                opts.device_maps,
                opts.devices,
            )
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(
        all_names,
        all_device_counts,
        all_device_maps,
        all_devices,
        is_static_group=False,
    )
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform RPC call to all workers, including itself, to include newly joined
    # worker information and device maps.
    for worker_name in all_names:
        # Set device list for each worker.
        all_devices[worker_name] = _create_device_list(
            all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
        )
        api.rpc_sync(
            worker_name,
            _update_group_membership,
            args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
        )
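
# Context sketch (illustrative, kept in comments): the handler below branches on
# whether `world_size` is set. Roughly, user code reaches the static branch by passing
# an explicit world_size to rpc.init_rpc, and the dynamic branch by leaving it unset;
# exact user-facing arguments may vary between releases, so treat these as examples.
#
#     import torch.distributed.rpc as rpc
#
#     rpc.init_rpc("worker0", rank=0, world_size=2)   # static group
#     rpc.init_rpc("worker0", rank=0)                 # dynamic group (no world_size)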


def _tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    from . import TensorPipeAgent
    from . import TensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
        raise TypeError(
            f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    device_count = torch.cuda.device_count()

    is_static_group = True if world_size else False
    # world_size is specified, so this is a static group (ranks cannot join and leave).
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC requests to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before rank 0 finishes
        # _all_gather.
        group.barrier().wait()

        return agent
    # Initialization for dynamic RPC (ranks can join and leave).
    else:
        with _group_membership_management(store, name, True):
            # Construct a TensorPipeAgent with empty reverse_device_maps and devices;
            # these properties will be updated after initialization.
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in the group that this rank has joined and set
                # devices and reverse_device_maps. This is a synchronous operation
                # that completes once all existing ranks are updated.
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                api.shutdown()
                raise

            return agent


register_backend(
    "TENSORPIPE",
    _tensorpipe_construct_rpc_backend_options_handler,
    _tensorpipe_init_backend_handler,
)