Spaces:
No application file
No application file
"""Benchmarking script to test the throughput of serving workers.""" | |
import argparse | |
import json | |
import requests | |
import threading | |
import time | |
from fastchat.conversation import get_conv_template | |
def _get_worker_address(controller_addr, model_name):
    """Ask the controller for the address of a worker serving ``model_name``.

    Returns the ``"address"`` field of the controller's JSON reply; ``main``
    treats an empty string as "no worker available".
    """
    ret = requests.post(
        controller_addr + "/get_worker_address", json={"model": model_name}
    )
    return ret.json()["address"]


def _last_stream_payload(response):
    """Decode the final JSON object of a NUL-byte-delimited worker response.

    Empty fragments (such as the one produced by a trailing delimiter) are
    skipped, so the last real payload is found whether or not the body ends
    with a delimiter.

    Raises:
        ValueError: if the body holds no payload or the payload is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    chunks = [
        chunk
        for chunk in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        )
        if chunk
    ]
    if not chunks:
        raise ValueError("worker returned an empty response")
    return json.loads(chunks[-1].decode("utf-8"))


def _count_new_words(output_text, prompt):
    """Estimate generated words: space-split words of output minus prompt.

    NOTE(review): subtracting the prompt assumes the worker echoes the prompt
    at the start of ``output_text`` -- confirm against the worker settings.
    """
    return len(output_text.split(" ")) - len(prompt.split(" "))


def main():
    """Measure generation throughput of a worker under concurrent load.

    Reads the module-level ``args`` produced by the CLI parser. The target
    worker comes from ``--worker-address`` or, when that is unset, from the
    controller. ``args.n_thread`` threads each POST the same prompt to
    ``/worker_generate_stream``; elapsed time and words/s are printed.
    Requests that fail (transport error, unparsable reply, non-zero
    ``error_code``) are reported and excluded from the word count.
    """
    # Bind unconditionally: send_request() needs it for --test-dispatch even
    # when --worker-address is given (previously a NameError in that case).
    controller_addr = args.controller_address
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = sorted(ret.json()["models"])
        print(f"Models: {models}")
        worker_addr = _get_worker_address(controller_addr, args.model_name)
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        print(f"No worker available for model {args.model_name!r}.")
        return

    conv = get_conv_template("vicuna_v1.1")
    conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words")
    prompt_template = conv.get_prompt()
    # Every thread sends the identical prompt.
    prompts = [prompt_template] * args.n_thread

    headers = {"User-Agent": "fastchat Client"}
    ploads = [
        {
            "model": args.model_name,
            "prompt": prompt,
            "max_new_tokens": args.max_new_tokens,
            "temperature": 0.0,
        }
        for prompt in prompts
    ]

    def send_request(results, i):
        """Run request ``i`` and store its new-word count in ``results[i]``.

        On failure the slot stays ``None`` and a message is printed, so one
        bad request cannot crash the aggregation in main().
        """
        try:
            if args.test_dispatch:
                # Query the controller per request to exercise dispatching.
                thread_worker_addr = _get_worker_address(
                    controller_addr, args.model_name
                )
            else:
                thread_worker_addr = worker_addr
            print(f"thread {i} goes to {thread_worker_addr}")
            response = requests.post(
                thread_worker_addr + "/worker_generate_stream",
                headers=headers,
                json=ploads[i],
                stream=False,
            )
            payload = _last_stream_payload(response)
            error_code = payload.get("error_code", 0)
            if error_code != 0:
                # The text of an error reply is a message, not generated output.
                print(
                    f"thread {i} failed with error code {error_code}: "
                    f"{payload.get('text')}"
                )
                return
            results[i] = _count_new_words(payload["text"], prompts[i])
        except (requests.RequestException, ValueError, KeyError) as err:
            print(f"thread {i} failed: {err!r}")

    # perf_counter is monotonic, so wall-clock adjustments cannot distort
    # the measured duration (time.time() offers no such guarantee).
    tik = time.perf_counter()
    results = [None] * args.n_thread
    threads = []
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(f"Time (POST): {time.perf_counter() - tik} s")

    completed = [n for n in results if n is not None]
    n_failed = len(results) - len(completed)
    if n_failed:
        print(
            f"{n_failed} of {len(results)} requests failed; "
            "excluded from throughput."
        )
    n_words = sum(completed)
    time_seconds = time.perf_counter() - tik
    throughput = n_words / time_seconds if time_seconds > 0 else 0.0
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {throughput} words/s."
    )
if __name__ == "__main__":
    # Build the CLI from a flag table, then expose the parsed options as the
    # module-level ``args`` that main() reads before running the benchmark.
    parser = argparse.ArgumentParser()
    cli_flags = (
        ("--controller-address", {"type": str, "default": "http://localhost:21001"}),
        ("--worker-address", {"type": str}),
        ("--model-name", {"type": str, "default": "vicuna"}),
        ("--max-new-tokens", {"type": int, "default": 2048}),
        ("--n-thread", {"type": int, "default": 8}),
        ("--test-dispatch", {"action": "store_true"}),
    )
    for flag, options in cli_flags:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main()