# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib.util
import os
import subprocess
import sys
from typing import Dict, List, Optional

from swift.utils import get_logger

logger = get_logger()

# Maps each CLI command name to the module whose source file is executed for it.
ROUTE_MAPPING: Dict[str, str] = {
    'pt': 'swift.cli.pt',
    'sft': 'swift.cli.sft',
    'infer': 'swift.cli.infer',
    'merge-lora': 'swift.cli.merge_lora',
    'web-ui': 'swift.cli.web_ui',
    'deploy': 'swift.cli.deploy',
    'rollout': 'swift.cli.rollout',
    'rlhf': 'swift.cli.rlhf',
    'sample': 'swift.cli.sample',
    'export': 'swift.cli.export',
    'eval': 'swift.cli.eval',
    'app': 'swift.cli.app',
}

# Commands that are launched through `torch.distributed.run` when torchrun
# environment variables are present; every other command runs as a plain script.
_TORCHRUN_METHODS = frozenset({'pt', 'sft', 'rlhf', 'infer'})

# Environment variables forwarded to torchrun as `--<lowercased name> <value>`.
_TORCHRUN_ENV_KEYS = ('NPROC_PER_NODE', 'MASTER_PORT', 'NNODES', 'NODE_RANK', 'MASTER_ADDR')


def use_torchrun() -> bool:
    """Return True if NPROC_PER_NODE or NNODES is set in the environment."""
    return os.getenv('NPROC_PER_NODE') is not None or os.getenv('NNODES') is not None


def get_torchrun_args() -> Optional[List[str]]:
    """Build torchrun command-line flags from environment variables.

    Returns:
        None when `use_torchrun()` is False. Otherwise a flat list of
        flag/value pairs, e.g. ``['--nproc_per_node', '4', '--nnodes', '1']``,
        containing one pair for each variable in ``_TORCHRUN_ENV_KEYS`` that is
        set (unset variables are skipped).
    """
    if not use_torchrun():
        return None
    torchrun_args: List[str] = []
    for env_key in _TORCHRUN_ENV_KEYS:
        env_val = os.getenv(env_key)
        if env_val is not None:
            torchrun_args += [f'--{env_key.lower()}', env_val]
    return torchrun_args


def _compat_web_ui(argv: List[str]) -> None:
    """[compat] Redirect legacy ``web-ui`` invocations to the ``app`` command.

    Rewrites ``argv[0]`` in place to ``'app'`` when the command is
    ``web-ui``/``web_ui`` and any of ``--model``, ``--adapters`` or
    ``--ckpt_dir`` appears in ``argv``, and logs a hint. Does nothing for an
    empty ``argv``.
    """
    if not argv:
        return
    method_name = argv[0]
    legacy_flags = ('--model', '--adapters', '--ckpt_dir')
    if method_name in {'web-ui', 'web_ui'} and any(flag in argv for flag in legacy_flags):
        argv[0] = 'app'
        logger.warning('Please use `swift app`.')


def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
    """Dispatch ``<command> [args...]`` from ``sys.argv`` to its CLI module.

    The routed module's source file is executed in a child process with the
    current interpreter. For commands in ``_TORCHRUN_METHODS``, when
    NPROC_PER_NODE or NNODES is set, the child is started via
    ``python -m torch.distributed.run`` with flags from ``get_torchrun_args``.
    Underscores in the command name are treated as hyphens, so ``merge_lora``
    resolves to ``merge-lora``.

    Args:
        route_mapping: Command name -> module name. Falls back to
            ``ROUTE_MAPPING`` when None or empty.

    Raises:
        SystemExit: If no command is given, the command is unknown, or the
            child process exits with a non-zero status (that status is
            propagated).
        ModuleNotFoundError: If the routed module cannot be located.
    """
    route_mapping = route_mapping or ROUTE_MAPPING
    argv = sys.argv[1:]
    available = ', '.join(route_mapping)
    if not argv:
        # Fail with a readable message rather than an IndexError on argv[0].
        sys.exit(f'No command given. Available commands: {available}')
    _compat_web_ui(argv)
    method_name = argv[0].replace('_', '-')
    argv = argv[1:]

    module_name = route_mapping.get(method_name)
    if module_name is None:
        sys.exit(f'Unknown command: {method_name!r}. Available commands: {available}')
    spec = importlib.util.find_spec(module_name)
    # find_spec returns None (rather than raising) when the module is missing.
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(
            f'Cannot locate module {module_name!r} for command {method_name!r}.', name=module_name)
    file_path = spec.origin

    torchrun_args = get_torchrun_args()
    python_cmd = sys.executable
    if torchrun_args is None or method_name not in _TORCHRUN_METHODS:
        args = [python_cmd, file_path, *argv]
    else:
        args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]
    print(f"run sh: `{' '.join(args)}`", flush=True)
    result = subprocess.run(args)
    if result.returncode != 0:
        sys.exit(result.returncode)


if __name__ == '__main__':
    cli_main()