FIRE / src /serve /launch_all_serve.py
zhangbofei
feat: change to fstchat
6dc0c9c
raw
history blame
8.45 kB
"""
Usage: python launch_all_serve.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
Workers are listed in format of `model-path`@`host`@`port`
The key mechanism behind this script is:
1. execute shell commands to launch the controller/worker/openai-api-server;
2. check the logs of the controller/worker/openai-api-server to ensure that each server has launched properly.
Note that a few non-critical `fastchat.serve` command options are not currently supported.
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import subprocess
import re
import argparse
# Directory that receives one log file per launched service.
LOGDIR = "./logs/"
# exist_ok=True creates the directory when it is missing without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(LOGDIR, exist_ok=True)
parser = argparse.ArgumentParser()
# ------multi worker-----------------
# One or more worker specs. The default is a one-element list (not a bare
# string) so that args.model_path_address has the same type whether or not the
# flag is given on the command line.
parser.add_argument(
    "--model-path-address",
    default=["THUDM/chatglm2-6b@localhost@20002"],
    nargs="+",
    type=str,
    help="model path, host, and port, formatted as model-path@host@port",
)
# ---------------controller-------------------------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
# Option names (dash-separated form) forwarded to fastchat.serve.controller.
controller_args = ["controller-host", "controller-port", "dispatch-method"]
# ----------------------worker------------------------------------------
# Options for the model worker processes. launch_worker() overwrites
# model_path, worker_host and worker_port for every worker spec it launches,
# so the defaults of those three only matter until then.
parser.add_argument("--worker-host", default="localhost", type=str)
parser.add_argument("--worker-port", default=21002, type=int)
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
# parser.add_argument(
# "--controller-address", type=str, default="http://localhost:21001"
# )
parser.add_argument(
    "--model-path",
    default="lmsys/vicuna-7b-v1.5",
    type=str,
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    default="main",
    type=str,
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    default="cuda",
    choices=["cpu", "cuda", "mps", "xpu", "npu"],
    type=str,
    help="The device type",
)
parser.add_argument(
    "--gpus",
    default="0",
    type=str,
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", default=1, type=int)
parser.add_argument(
    "--max-gpu-memory",
    type=str,
    help="The maximum memory per gpu. Use a string like '13Gib'",
)

# Quantization / offloading switches.
parser.add_argument("--load-8bit", help="Use 8-bit quantization", action="store_true")
parser.add_argument(
    "--cpu-offloading",
    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
    action="store_true",
)
parser.add_argument(
    "--gptq-ckpt",
    default=None,
    type=str,
    help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    default=16,
    choices=[2, 3, 4, 8, 16],
    type=int,
    help="#bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    default=-1,
    type=int,
    help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    help="Whether to apply the activation order GPTQ heuristic",
    action="store_true",
)

# Serving behaviour.
parser.add_argument(
    "--model-names",
    # A single comma-separated argument is split into a list of names.
    type=lambda s: s.split(","),
    help="Optional display comma separated names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    default=5,
    type=int,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", default=2, type=int)
parser.add_argument("--no-register", action="store_true")

# Option names (dash-separated form) forwarded to fastchat.serve.model_worker.
# "controller-address" is not a parser option; it is added to the parsed
# namespace after parsing.
worker_args = [
    "worker-host", "worker-port",
    "model-path", "revision",
    "device", "gpus", "num-gpus", "max-gpu-memory",
    "load-8bit", "cpu-offloading",
    "gptq-ckpt", "gptq-wbits", "gptq-groupsize", "gptq-act-order",
    "model-names",
    "limit-worker-concurrency",
    "stream-interval",
    "no-register",
    "controller-address",
]
# -----------------openai server---------------------------
# Options for the OpenAI-compatible API server process.
parser.add_argument("--server-host", default="localhost", type=str, help="host name")
parser.add_argument("--server-port", default=8001, type=int, help="port number")
parser.add_argument(
    "--allow-credentials",
    help="allow credentials",
    action="store_true",
)
# parser.add_argument(
# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
# )
# parser.add_argument(
# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
# )
# parser.add_argument(
# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
# )
parser.add_argument(
    "--api-keys",
    # A single comma-separated argument is split into a list of keys.
    type=lambda s: s.split(","),
    help="Optional list of comma separated API keys",
)

# Option names (dash-separated form) forwarded to
# fastchat.serve.openai_api_server. "controller-address" is not a parser
# option; it is added to the parsed namespace after parsing.
server_args = [
    "server-host", "server-port",
    "allow-credentials",
    "api-keys",
    "controller-address",
]
# Parse the command line at module level; the launcher functions below all read
# this shared namespace.
args = parser.parse_args()
# The controller URL is derived from the controller host/port rather than
# parsed. It is stored under a dashed attribute name so that it matches the
# "controller-address" entries of worker_args / server_args when string_args()
# filters the namespace.
setattr(
    args,
    "controller-address",
    f"http://{args.controller_host}:{args.controller_port}",
)
if args.gpus:
    # Every requested GPU needs an id listed in --gpus.
    if args.num_gpus > len(args.gpus.split(",")):
        raise ValueError(
            f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
        )
    # Set in this process's environment; the services are started from here
    # via subprocess.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
# Launch template (str.format placeholders):
#   {0} fastchat.serve module: controller, model_worker, openai_api_server
#   {1} command-line options
#   {2} LOGDIR
#   {3} log file name (without the .log suffix)
# The process is started in the background with stdout and stderr redirected
# to the log file.
base_launch_sh = (
    "nohup python3 -m fastchat.serve.{0} {1} "
    ">{2}/{3}.log 2>&1 &"
)
# Readiness-check template (str.format placeholders):
#   {0} LOGDIR
#   {1} log file name (without the .log suffix)
#   {2} service name used in the progress messages
# Loops, sleeping between checks, until the log contains "Uvicorn running on".
base_check_sh = "\n".join(
    [
        """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do""",
        "sleep 1s;",
        'echo "wait {2} running"',
        "done",
        "echo '{2} running' ",
    ]
)
def string_args(args, args_list):
    """Render selected parsed options as a command-line fragment.

    Args:
        args: Parsed options; an ``argparse.Namespace`` or any object whose
            attributes are exposed by ``vars()``.
        args_list: Dash-separated option names to forward. Attributes of
            ``args`` whose dashed name is not listed are ignored.

    Returns:
        The concatenation of ``" --name value "`` and ``" --flag "`` pieces,
        ready to be spliced into a shell command. Empty when nothing is
        forwarded.

    Notes:
        * ``None``, ``False``, empty strings and empty collections are
          omitted; a numeric ``0`` is forwarded because it is an explicit
          value, not a missing one.
        * List/tuple/set values are emitted as one comma-separated argument,
          the same form in which this launcher accepts its own list options
          (``--model-names``, ``--api-keys``).
          NOTE(review): assumes the fastchat.serve entry points parse those
          options the same way -- confirm against the installed version.
        * Values are shell-quoted so that paths containing spaces or shell
          metacharacters survive ``shell=True``.
    """
    import shlex  # local import: this is the only block that needs it

    pieces = []
    for name, value in vars(args).items():
        option = name.replace("_", "-")
        if option not in args_list:
            continue
        # For names containing "port" or "host" only the last dash-separated
        # segment is kept, e.g. "worker-port" -> "port".
        if re.search("port|host", option):
            option = option.split("-")[-1]

        # Booleans are pure switches: present when True, absent when False.
        if isinstance(value, bool):
            if value:
                pieces.append(f" --{option} ")
            continue

        # Skip unset/empty values, but keep an explicit numeric zero.
        if not value and not isinstance(value, (int, float)):
            continue

        if isinstance(value, (list, tuple, set)):
            value = ",".join(str(element) for element in value)
        pieces.append(f" --{option} {shlex.quote(str(value))} ")
    return "".join(pieces)
def launch_worker(item):
    """Start one ``fastchat.serve.model_worker`` and wait until it is serving.

    The model path, host and port parsed from ``item`` are written onto the
    module-level ``args`` namespace (``model_path``, ``worker_host``,
    ``worker_port``) before the worker options are rendered.

    Args:
        item: Worker spec formatted as ``model-path@host@port``.

    Raises:
        ValueError: If ``item`` does not have three ``@``-separated fields.
        subprocess.CalledProcessError: If the launch or readiness-check shell
            command exits with a non-zero status.
    """
    # Split from the right so that a model path which itself contains "@"
    # stays intact; host and port are always the last two fields.
    fields = item.rsplit("@", 2)
    if len(fields) != 3:
        raise ValueError(
            f"Invalid worker spec {item!r}: expected model-path@host@port"
        )

    # Log file stem: last path component of the spec (either slash style),
    # with "-", "@" and "." replaced by "_".
    log_name = (
        item.split("/")[-1]
        .split("\\")[-1]
        .replace("-", "_")
        .replace("@", "_")
        .replace(".", "_")
    )

    args.model_path, args.worker_host, args.worker_port = fields
    print("*" * 80)
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)

    worker_sh = base_launch_sh.format(
        "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
    )
    worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
    # Start the worker in the background, then block until its log reports
    # that the server is up.
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)
def launch_all():
    """Bring up the controller, every model worker, and the OpenAI API server.

    Services are started in that order. Each start is followed by a blocking
    readiness check on the service's log file before the next one begins.
    """

    def _start_and_wait(module, options, log_stem):
        # Launch the service in the background, then poll its log until ready.
        start_cmd = base_launch_sh.format(module, options, LOGDIR, log_stem)
        wait_cmd = base_check_sh.format(LOGDIR, log_stem, module)
        subprocess.run(start_cmd, shell=True, check=True)
        subprocess.run(wait_cmd, shell=True, check=True)

    _start_and_wait("controller", string_args(args, controller_args), "controller")

    # model_path_address is either a single spec string or a list of specs.
    addresses = args.model_path_address
    if not isinstance(addresses, str):
        for position, address in enumerate(addresses):
            print(f"loading {position}th model:{address}")
            launch_worker(address)
    else:
        launch_worker(addresses)

    _start_and_wait(
        "openai_api_server",
        string_args(args, server_args),
        "openai_api_server",
    )


if __name__ == "__main__":
    launch_all()