'''
A script that benchmarks queue performance. It can be used to compare the performance
of the queue on a given branch against the main branch. By default, it runs 100 jobs in
batches of 20 and prints the average time per job for each input type. The inference
time for each job (excluding the network overhead of sending/receiving the data) is
0.5 seconds. Each job sends one of a text, image, audio, or video input, and the output
is identical to the input.

Navigate to the root directory of the gradio repo and run:
>> python scripts/benchmark_queue.py

You can specify the number of jobs to run with the -n parameter (jobs are always sent
in batches of 20):
>> python scripts/benchmark_queue.py -n 1000

The results are printed to the console, but you can specify a path to save them to
with the -o parameter:
>> python scripts/benchmark_queue.py -n 1000 -o results.json
'''
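
# The printed/saved results are a mapping from input type to mean seconds per job,
# measured from the client side (queue wait + network + inference). A purely
# illustrative sketch of the shape (placeholders, not measured values):
#
#     {"text": <mean seconds>, "image": <mean seconds>,
#      "audio": <mean seconds>, "video": <mean seconds>}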

import argparse
import asyncio
import json
import random
import time

import pandas as pd
import websockets

import gradio as gr
from gradio_client import media_data


def identity_with_sleep(x):
    # Simulate a fixed 0.5 s inference time, then echo the input back unchanged.
    time.sleep(0.5)
    return x


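# A demo with one endpoint per input type; each endpoint simply echoes its input
# back after the simulated 0.5 s inference delay.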
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_txt = gr.Text()
            output_text = gr.Text()
            submit_text = gr.Button()
            submit_text.click(identity_with_sleep, input_txt, output_text, api_name="text")
        with gr.Column():
            input_img = gr.Image()
            output_img = gr.Image()
            submit_img = gr.Button()
            submit_img.click(identity_with_sleep, input_img, output_img, api_name="img")
        with gr.Column():
            input_audio = gr.Audio()
            output_audio = gr.Audio()
            submit_audio = gr.Button()
            submit_audio.click(identity_with_sleep, input_audio, output_audio, api_name="audio")
        with gr.Column():
            input_video = gr.Video()
            output_video = gr.Video()
            submit_video = gr.Button()
            submit_video.click(identity_with_sleep, input_video, output_video, api_name="video")
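# Launch with the queue enabled; prevent_thread_lock=True returns control to this
# script so the benchmark client below can drive the server in-process, and
# quiet=True suppresses the server's console output during the run.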
demo.queue(max_size=50).launch(prevent_thread_lock=True, quiet=True)


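# Map each input type to (fn_index, payload). The fn_index values assume the
# .click() events above are registered in the order text, image, audio, video.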
FN_INDEX_TO_DATA = {
    "text": (0, "A longish text " * 15),
    "image": (1, media_data.BASE64_IMAGE),
    "audio": (2, media_data.BASE64_AUDIO),
    "video": (3, media_data.BASE64_VIDEO)
}


async def get_prediction(host):
    """Submit a single randomly-typed job over the queue websocket and time it end to end."""
    async with websockets.connect(host) as ws:
        completed = False
        name = random.choice(["image", "text", "audio", "video"])
        fn_to_hit, data = FN_INDEX_TO_DATA[name]
        start = time.time()

        while not completed:
            msg = json.loads(await ws.recv())
            # Reply to the server's handshake: send a session hash when asked,
            # then the input payload when the job reaches the front of the queue.
            if msg["msg"] == "send_hash":
                await ws.send(json.dumps({"fn_index": fn_to_hit, "session_hash": "shdce"}))
            if msg["msg"] == "send_data":
                await ws.send(json.dumps({"data": [data], "fn_index": fn_to_hit}))
            if msg["msg"] == "process_completed":
                completed = True
                end = time.time()
                return {"fn_to_hit": name, "duration": end - start}


async def main(host, n_results=100):
    """Send jobs to the queue in batches of 20 until n_results have completed."""
    results = []
    while len(results) < n_results:
        batch_results = await asyncio.gather(*[get_prediction(host) for _ in range(20)])
        for result in batch_results:
            if result:
                results.append(result)

    # Average the measured durations per input type.
    data = pd.DataFrame(results).groupby("fn_to_hit").agg({"duration": "mean"}).reset_index()
    return {"fn_to_hit": data["fn_to_hit"].to_list(), "duration": data["duration"].to_list()}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the gradio queue")
    parser.add_argument("-n", "--n_jobs", type=int, help="number of jobs", default=100, required=False)
    parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
    args = parser.parse_args()

    host = f"{demo.local_url.replace('http', 'ws')}queue/data"
    data = asyncio.run(main(host, n_results=args.n_jobs))
    data = dict(zip(data["fn_to_hit"], data["duration"]))
    
    print(data)
    
    if args.output:
        print("Writing results to:", args.output)
        with open(args.output, "w") as f:
            json.dump(data, f)