# NeuralChat-LLAMA-POC / fastchat/serve/test_throughput.py
# Author: lvkaokao
# Commit message: update codes.
# Commit: 5a7ab71
# (file-viewer links: raw / history / blame)
# No virus
# Size: 3.98 kB
"""Benchmarking script to test the throughput of serving workers."""
import argparse
import json
import requests
import threading
import time
from fastchat.conversation import default_conversation
def main():
    """Measure the generation throughput of a serving worker.

    Reads the module-level ``args`` namespace produced by the command-line
    parser. A worker address is taken from ``args.worker_address`` or, when
    that is not set, obtained from the controller. ``args.n_thread`` identical
    prompts are then posted concurrently to the worker's
    ``/worker_generate_stream`` endpoint, and the elapsed time plus an
    approximate words-per-second figure are printed.

    Returns:
        None. Everything is reported on stdout. Returns early, before any
        generation request is sent, when the resolved worker address is an
        empty string.
    """
    # Bind the controller address unconditionally: with ``--test-dispatch``
    # every thread queries the controller, even when ``--worker-address``
    # was supplied and the controller lookup below is skipped.
    controller_addr = args.controller_address

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": args.model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        return

    # Every thread sends the same prompt built from the default conversation.
    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    prompt_template = conv.get_prompt()
    prompts = [prompt_template for _ in range(args.n_thread)]

    headers = {"User-Agent": "fastchat Client"}
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompt,
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
            # "stop": conv.sep,
        }
        for prompt in prompts
    ]

    def send_request(results, i):
        """Post prompt ``i`` to a worker and store its word count in ``results[i]``.

        If the request or the parsing raises, ``results[i]`` keeps its
        initial ``None`` so the caller can tell the thread did not finish.
        """
        if args.test_dispatch:
            # Ask the controller for a worker on every request.
            ret = requests.post(
                controller_addr + "/get_worker_address", json={"model": args.model_name}
            )
            thread_worker_addr = ret.json()["address"]
        else:
            thread_worker_addr = worker_addr
        print(f"thread {i} goes to {thread_worker_addr}")
        response = requests.post(
            thread_worker_addr + "/worker_generate_stream",
            headers=headers,
            json=ploads[i],
            stream=False,
        )
        chunks = list(
            response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")
        )
        # The second-to-last chunk is treated as the final message.
        # NOTE(review): presumably the body ends with a trailing b"\0", which
        # leaves an empty last chunk -- confirm against the worker.
        final_message = json.loads(chunks[-2].decode("utf-8"))
        response_new_words = final_message["text"]
        error_code = final_message["error_code"]
        if error_code:
            # NOTE(review): a non-zero error_code presumably signals a
            # worker-side failure -- confirm against the worker API.
            print(f"thread {i} returned error code {error_code}")
        # NOTE(review): the prompt's word count is subtracted, presumably
        # because the returned text echoes the prompt -- confirm.
        results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    # use N threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")

    # A thread that raised leaves ``None`` in its slot; skip those entries so
    # the summary is still printed instead of ``sum`` failing on ``None``.
    completed = [count for count in results if count is not None]
    n_failed = len(results) - len(completed)
    if n_failed:
        print(f"{n_failed} of {len(results)} threads failed; their output is not counted.")

    n_words = sum(completed)
    time_seconds = time.time() - tik
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds} words/s."
    )
if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword arguments)
    # pairs and registered in this order.
    option_specs = (
        ("--controller-address", {"type": str, "default": "http://localhost:21001"}),
        ("--worker-address", {"type": str}),
        ("--model-name", {"type": str, "default": "vicuna"}),
        ("--max-new-tokens", {"type": int, "default": 2048}),
        ("--n-thread", {"type": int, "default": 8}),
        ("--test-dispatch", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    # ``main`` reads this module-level namespace directly.
    args = cli.parse_args()
    main()