|
''' |
|
A script that benchmarks the queue performance, can be used to compare the performance |
|
of the queue on a given branch vs the main branch. By default, runs 100 jobs in batches |
|
of 20 and prints the average time per job. The inference time for each job (without the |
|
network overhead of sending/receiving the data) is 0.5 seconds. Each job sends one of: |
|
a text, image, audio, or video input and the output is the same as the input. |
|
|
|
Navigate to the root directory of the gradio repo and run: |
|
>> python scripts/benchmark_queue.py |
|
|
|
You can specify the number of jobs to run and the batch size with the -n parameter: |
|
>> python scripts/benchmark_queue.py -n 1000 |
|
|
|
The results are printed to the console, but you can specify a path to save the results |
|
to with the -o parameter: |
|
>> python scripts/benchmark_queue.py -n 1000 -o results.json |
|
''' |
|
|
|
import argparse |
|
import asyncio |
|
import json |
|
import random |
|
import time |
|
|
|
import pandas as pd |
|
import websockets |
|
|
|
import gradio as gr |
|
from gradio_client import media_data |
|
|
|
|
|
def identity_with_sleep(x): |
|
time.sleep(0.5) |
|
return x |
|
|
|
|
|
with gr.Blocks() as demo: |
|
with gr.Row(): |
|
with gr.Column(): |
|
input_txt = gr.Text() |
|
output_text = gr.Text() |
|
submit_text = gr.Button() |
|
submit_text.click(identity_with_sleep, input_txt, output_text, api_name="text") |
|
with gr.Column(): |
|
input_img = gr.Image() |
|
output_img = gr.Image() |
|
submit_img = gr.Button() |
|
submit_img.click(identity_with_sleep, input_img, output_img, api_name="img") |
|
with gr.Column(): |
|
input_audio = gr.Audio() |
|
output_audio = gr.Audio() |
|
submit_audio = gr.Button() |
|
submit_audio.click(identity_with_sleep, input_audio, output_audio, api_name="audio") |
|
with gr.Column(): |
|
input_video = gr.Video() |
|
output_video = gr.Video() |
|
submit_video = gr.Button() |
|
submit_video.click(identity_with_sleep, input_video, output_video, api_name="video") |
|
demo.queue(max_size=50).launch(prevent_thread_lock=True, quiet=True) |
|
|
|
|
|
FN_INDEX_TO_DATA = { |
|
"text": (0, "A longish text " * 15), |
|
"image": (1, media_data.BASE64_IMAGE), |
|
"audio": (2, media_data.BASE64_AUDIO), |
|
"video": (3, media_data.BASE64_VIDEO) |
|
} |
|
|
|
|
|
async def get_prediction(host): |
|
async with websockets.connect(host) as ws: |
|
completed = False |
|
name = random.choice(["image", "text", "audio", "video"]) |
|
fn_to_hit, data = FN_INDEX_TO_DATA[name] |
|
start = time.time() |
|
|
|
while not completed: |
|
msg = json.loads(await ws.recv()) |
|
if msg["msg"] == "send_data": |
|
await ws.send(json.dumps({"data": [data], "fn_index": fn_to_hit})) |
|
if msg["msg"] == "send_hash": |
|
await ws.send(json.dumps({"fn_index": fn_to_hit, "session_hash": "shdce"})) |
|
if msg["msg"] == "process_completed": |
|
completed = True |
|
end = time.time() |
|
return {"fn_to_hit": name, "duration": end - start} |
|
|
|
|
|
async def main(host, n_results=100): |
|
results = [] |
|
while len(results) < n_results: |
|
batch_results = await asyncio.gather(*[get_prediction(host) for _ in range(20)]) |
|
for result in batch_results: |
|
if result: |
|
results.append(result) |
|
|
|
data = pd.DataFrame(results).groupby("fn_to_hit").agg({"mean"}) |
|
data.columns = data.columns.get_level_values(0) |
|
data = data.reset_index() |
|
data = {"fn_to_hit": data["fn_to_hit"].to_list(), "duration": data["duration"].to_list()} |
|
return data |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description="Upload a demo to a space") |
|
parser.add_argument("-n", "--n_jobs", type=int, help="number of jobs", default=100, required=False) |
|
parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False) |
|
args = parser.parse_args() |
|
|
|
host = f"{demo.local_url.replace('http', 'ws')}queue/data" |
|
data = asyncio.run(main(host, n_results=args.n_jobs)) |
|
data = dict(zip(data["fn_to_hit"], data["duration"])) |
|
|
|
print(data) |
|
|
|
if args.output: |
|
print("Writing results to:", args.output) |
|
json.dump(data, open(args.output, "w")) |
|
|