Spaces:
Sleeping
Sleeping
| import os | |
| import socket | |
| from functools import cache, partial, wraps | |
| from typing import Callable | |
| import deepspeed | |
| import torch | |
| from deepspeed.accelerator import get_accelerator | |
| from torch.distributed import broadcast_object_list | |
def get_free_port():
    """Return a TCP port number that the OS currently reports as free.

    A throwaway socket is bound to port 0 so the kernel picks an unused
    port; the chosen number is read back and the socket is closed.

    Returns:
        The port number chosen by the operating system.

    Note:
        The port is released before this function returns, so another
        process may claim it before the caller binds to it.
    """
    # The context manager closes the probe socket deterministically instead
    # of leaving it to garbage collection.
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
def fix_unset_envs():
    """Populate single-process defaults for the distributed environment.

    If any of RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT or LOCAL_RANK is
    already present in the environment, nothing is modified. Otherwise all
    five variables are set: rank 0, world size 1, master address
    "localhost", a port obtained from get_free_port(), and local rank 0.
    """
    defaults = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(get_free_port()),
        "LOCAL_RANK": "0",
    }

    # All-or-nothing: a single pre-existing variable leaves the whole
    # environment untouched.
    if any(name in os.environ for name in defaults):
        return

    os.environ.update(defaults)
def init_distributed():
    """Initialise distributed communication and select this process's device.

    Steps:
        1. Fill in single-process environment defaults via fix_unset_envs().
        2. Initialise DeepSpeed's distributed backend using the communication
           backend name reported by the active accelerator.
        3. Bind this process to the device whose index is local_rank().
    """
    fix_unset_envs()
    accelerator = get_accelerator()
    deepspeed.init_distributed(accelerator.communication_backend_name())
    # Select the device through the accelerator abstraction (the same object
    # that supplied the backend name) rather than hard-coding torch.cuda, so
    # that non-CUDA accelerators are handled as well.
    accelerator.set_device(local_rank())
def local_rank() -> int:
    """Return the LOCAL_RANK environment variable as an int (0 when unset)."""
    raw = os.environ.get("LOCAL_RANK", 0)
    return int(raw)
def global_rank() -> int:
    """Return the RANK environment variable as an int (0 when unset)."""
    raw = os.environ.get("RANK", 0)
    return int(raw)
def is_local_leader() -> bool:
    """Return True when this process has local rank 0."""
    rank = local_rank()
    return rank == 0
def is_global_leader() -> bool:
    """Return True when this process has global rank 0."""
    rank = global_rank()
    return rank == 0
def leader_only(leader_only_type, fn: Callable | None = None, boardcast_return=False) -> Callable:
    """
    Restrict a function so that its body only executes on the leader process.

    Can be applied directly (``leader_only(kind, fn)``) or used as a decorator
    factory (``leader_only(kind, boardcast_return=True)``), in which case a
    decorator is returned.

    Args:
        leader_only_type: "local" to run only where is_local_leader() is true,
            "global" to run only where is_global_leader() is true
        fn: The function to decorate
        boardcast_return: Whether to broadcast the return value to all processes
            (may cause deadlock if the function calls another decorated function).
            When enabled, init_distributed() is called on every invocation and
            the result is broadcast from rank 0.

    Returns:
        The wrapped function, or a decorator when ``fn`` is None. On non-leader
        processes the wrapped function returns None unless ``boardcast_return``
        is set.

    Raises:
        RuntimeError: When the decorator is applied to a function that already
            carries a leader-only marker.
        ValueError: When the decorator is applied with an unknown
            ``leader_only_type``.
    """

    def wrapper(fn):
        if hasattr(fn, "__leader_only_type__"):
            raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")

        # Validate before tagging so a rejected decoration leaves fn untouched.
        if leader_only_type == "local":
            guard_fn = is_local_leader
        elif leader_only_type == "global":
            guard_fn = is_global_leader
        else:
            raise ValueError(f"Unknown leader_only_type: {leader_only_type}")

        fn.__leader_only_type__ = leader_only_type

        # wraps() keeps fn's name/docstring and copies fn.__dict__, so the
        # marker set above is also visible on the returned function; stacking
        # a second leader-only decorator is therefore detected.
        @wraps(fn)
        def wrapped(*args, **kwargs):
            if boardcast_return:
                init_distributed()
            # Single-slot list: the container shape broadcast_object_list expects.
            obj_list = [None]
            if guard_fn():
                obj_list[0] = fn(*args, **kwargs)
            if boardcast_return:
                broadcast_object_list(obj_list, src=0)
            return obj_list[0]

        return wrapped

    if fn is None:
        return wrapper
    return wrapper(fn)
# Convenience decorators: leader_only with leader_only_type pre-bound.
# Both forms work: `@local_leader_only` and
# `@local_leader_only(boardcast_return=True)`.
# Runs the function only where is_local_leader() is true.
| local_leader_only = partial(leader_only, "local") | |
# Runs the function only where is_global_leader() is true.
| global_leader_only = partial(leader_only, "global") | |