"""
A script that benchmarks the queue performance, can be used to compare the
performance of the queue on a given branch vs the main branch.

By default, runs 100 jobs in batches of 20 and prints the average time per job
for each input type. The inference time for each job (without the network
overhead of sending/receiving the data) is 0.5 seconds. Each job sends one of:
a text, image, audio, or video input and the output is the same as the input.

Navigate to the root directory of the gradio repo and run:
>> python scripts/benchmark_queue.py

You can specify the number of jobs to run with the -n parameter:
>> python scripts/benchmark_queue.py -n 1000

The results are printed to the console, but you can specify a path to save the
results to with the -o parameter:
>> python scripts/benchmark_queue.py -n 1000 -o results.json
"""

import argparse
import asyncio
import json
import random
import time

import pandas as pd
import websockets

import gradio as gr
from gradio_client import media_data


def identity_with_sleep(x):
    """Return ``x`` unchanged after sleeping 0.5 s to simulate inference time."""
    time.sleep(0.5)
    return x


# One column per input type; each button echoes its input back through
# ``identity_with_sleep``. The order in which the click handlers are registered
# must match the indices stored in FN_INDEX_TO_DATA below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_txt = gr.Text()
            output_text = gr.Text()
            submit_text = gr.Button()
            submit_text.click(identity_with_sleep, input_txt, output_text, api_name="text")
        with gr.Column():
            input_img = gr.Image()
            output_img = gr.Image()
            submit_img = gr.Button()
            submit_img.click(identity_with_sleep, input_img, output_img, api_name="img")
        with gr.Column():
            input_audio = gr.Audio()
            output_audio = gr.Audio()
            submit_audio = gr.Button()
            submit_audio.click(identity_with_sleep, input_audio, output_audio, api_name="audio")
        with gr.Column():
            input_video = gr.Video()
            output_video = gr.Video()
            submit_video = gr.Button()
            submit_video.click(identity_with_sleep, input_video, output_video, api_name="video")

# Launched at module level (non-blocking) so that ``demo.local_url`` is
# available to the benchmark entry point below.
demo.queue(max_size=50).launch(prevent_thread_lock=True, quiet=True)


# Maps an input kind to (fn_index of its click handler, payload to send).
FN_INDEX_TO_DATA = {
    "text": (0, "A longish text " * 15),
    "image": (1, media_data.BASE64_IMAGE),
    "audio": (2, media_data.BASE64_AUDIO),
    "video": (3, media_data.BASE64_VIDEO),
}


async def get_prediction(host):
    """Run one randomly chosen job through the queue and time it.

    Opens a websocket to ``host``, answers the queue's ``send_hash`` and
    ``send_data`` messages, and waits for ``process_completed``.

    Args:
        host: Websocket URL of the queue endpoint.

    Returns:
        A dict with ``"fn_to_hit"`` (the input kind that was sent) and
        ``"duration"`` (seconds elapsed between the connection being
        established and the ``process_completed`` message).
    """
    async with websockets.connect(host) as ws:
        # Draw from the dict keys so the set of kinds has one source of truth.
        name = random.choice(list(FN_INDEX_TO_DATA))
        fn_to_hit, data = FN_INDEX_TO_DATA[name]
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments during the run.
        start = time.perf_counter()

        while True:
            msg = json.loads(await ws.recv())
            kind = msg["msg"]
            if kind == "send_data":
                await ws.send(json.dumps({"data": [data], "fn_index": fn_to_hit}))
            elif kind == "send_hash":
                await ws.send(json.dumps({"fn_index": fn_to_hit, "session_hash": "shdce"}))
            elif kind == "process_completed":
                end = time.perf_counter()
                return {"fn_to_hit": name, "duration": end - start}


async def main(host, n_results=100, batch_size=20):
    """Run ``n_results`` jobs in concurrent batches and average their durations.

    Args:
        host: Websocket URL of the queue endpoint.
        n_results: Total number of jobs to run.
        batch_size: Maximum number of jobs sent concurrently. Must be >= 1.

    Returns:
        A dict with two parallel lists: ``"fn_to_hit"`` (input kinds) and
        ``"duration"`` (mean duration in seconds for that kind). Both lists
        are empty when no job was run.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    results = []
    while len(results) < n_results:
        # Cap the final batch so exactly ``n_results`` jobs are run even when
        # it is not a multiple of ``batch_size``.
        n_jobs = min(batch_size, n_results - len(results))
        batch_results = await asyncio.gather(
            *(get_prediction(host) for _ in range(n_jobs))
        )
        results.extend(result for result in batch_results if result)

    if not results:
        # groupby would raise KeyError on an empty frame with no columns.
        return {"fn_to_hit": [], "duration": []}

    mean_durations = pd.DataFrame(results).groupby("fn_to_hit")["duration"].mean()
    return {
        "fn_to_hit": mean_durations.index.to_list(),
        "duration": mean_durations.to_list(),
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the performance of the gradio queue"
    )
    parser.add_argument("-n", "--n_jobs", type=int, help="number of jobs", default=100, required=False)
    parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
    args = parser.parse_args()

    # Only rewrite the scheme prefix (http -> ws, https -> wss).
    host = f"{demo.local_url.replace('http', 'ws', 1)}queue/data"
    data = asyncio.run(main(host, n_results=args.n_jobs))
    data = dict(zip(data["fn_to_hit"], data["duration"]))

    print(data)

    if args.output:
        print("Writing results to:", args.output)
        # Context manager guarantees the file is flushed and closed.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(data, f)