File size: 6,495 Bytes
315ec00
395a4be
 
315ec00
 
 
395a4be
315ec00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395a4be
315ec00
395a4be
315ec00
 
395a4be
 
315ec00
395a4be
315ec00
 
395a4be
 
 
315ec00
395a4be
315ec00
 
395a4be
315ec00
 
 
395a4be
 
315ec00
 
395a4be
 
 
 
 
 
 
315ec00
395a4be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315ec00
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import csv
import json
import time
import random
import itertools
from statistics import quantiles
import multiprocessing as mp

import tyro

from spitfight.colosseum.client import ControllerClient

# Address of the Colosseum controller under test. It is read from the
# environment at import time, so importing this module raises KeyError when
# COLOSSEUM_CONTROLLER_ADDR is not set.
CONTROLLER_ADDR = os.environ["COLOSSEUM_CONTROLLER_ADDR"]

# Prompts used for every load-test round: `main` submits each entry of this
# list to `request` once per concurrency level. The trailing `* 2` repeats the
# whole list, so every prompt is sent twice per round.
PROMPTS = [
    "What is Deep Learning?",
    "Write a poem about life.",
    "What is the basics of Rust?",
    "What is Python's GIL?",
    "What are Go channels and how do they compare with Rust flume channels?",
    "What is the difference between a list and a tuple in Python?",
    "How do I use Python's asyncio.wait?",
    "How do I accurately measure the execution time of a function in Python?",
    "How do I use Python's multiprocessing module?",
    "What is Python's built-in dataclasses module?",
    "How is Python's async/await different from Rust's async/await?",
    "What is Hugging Face Transformers?",
    "Tell me about your capabilities.",
    "When is your knowledge cutoff, and what does it mean?",
    "Explain Machine Learning in simple terms.",
    "Write a song that welcomes new students to the University of Michigan.",
    "Explain how to use the Pydantic library with a single code block.",
    "Write a poem about Jae-Won Chung, God of Computer Science.",
    "Write a poem about the University of Michigan.",
    "How do I get my new AI startup funded?",
    "Explain the notion of zero copy in programming.",
    "Explain the notion of zero knowledge proofs.",
    "Explain the notion of zero trust in cybersecurity.",
    "What is a monad in functional programming?",
    "What is a monad in category theory?",
    "How are monads implemented in both Haskell and OCaml?",
    "What is the difference between a monad and a functor?",
    "What is the difference between a monad and a monoid?",
    "How are monads used in Rust?",
    "What is a good name for a software library that makes ML energy efficient?",
    "What would be some good naming criteria for a tech startup?",
    "What is the opposite of democracy? Explain in detail.",
    "Why are people scared to be contacted by the IRS?",
    "What is fingerstyle guitar?",
    "How do I practice and play fingerstyle guitar?",
    "What is the difference between fingerstyle and classical guitar?",
    "What is the difference between classical and flamenco guitar?",
    "What is the difference between classical and jazz guitar?",
    "Explain the basics of the Django web framework.",
    "Explain the basics of the Flask web framework.",
    "Explain the basics of the FastAPI web framework.",
    "I really need to pee. What should I do?",
    "Why would one use Python's abc module?",
    "Explain Python type annotations and why they are useful.",
    "How do I create an immutable list in Python?",
    "How do I create a mutable tuple in Python?",
    "When does dropping out of a Computer Science PhD program make sense?",
    "What is the difference between a PhD and a Masters in Computer Science?",
    "How are software engineers and software developers different?",
    "Hi",
    "What's up",
    "How are you?",
    "What am I supposed to type here",
    "Is indoor vaping legal?",
    "What are the key points of the 14th amendment?",
    "I'm new to the US. What are some social taboos I should be aware of?",
] * 2


def request(prompt: str) -> tuple[str, float, float, float]:
    """Send one prompt to both Colosseum models and time the streamed output.

    The prompt is sent twice through a single ``ControllerClient``, once with
    ``index=0`` and once with ``index=1``. Both response streams are consumed
    to completion.

    Args:
        prompt: The user prompt to send.

    Returns:
        A tuple ``(request_id, latency, first_token_latency, tokens_per_second)``:
        - ``request_id``: the client's ``request_id`` attribute.
        - ``latency``: seconds from the first pull until both streams ended.
        - ``first_token_latency``: seconds until the first pair of items was
          produced, or ``-1.0`` if neither stream yielded anything.
        - ``tokens_per_second``: the number of streamed items from both
          streams divided by ``latency``, or ``0.0`` if ``latency`` is 0.
    """
    # Random 0-5 s delay so that pool workers do not all start at the same instant.
    time.sleep(random.random() * 5)
    client = ControllerClient(CONTROLLER_ADDR, timeout=60)

    first_token_latency = -1.0
    num_tokens = 0
    start_time = time.monotonic()
    # zip_longest pulls from the two streams in lockstep. A pair is produced
    # only after each stream has yielded (or is exhausted, giving None), so the
    # first pair marks when *both* streams have produced their first item.
    for i, (resp_a, resp_b) in enumerate(itertools.zip_longest(
        client.prompt(prompt, index=0),
        client.prompt(prompt, index=1),
    )):
        if i == 0:
            first_token_latency = time.monotonic() - start_time
        # Each non-None streamed item counts as one token. The text itself is
        # not needed for the measurement, so it is not accumulated.
        num_tokens += (resp_a is not None) + (resp_b is not None)

    latency = time.monotonic() - start_time
    # Guard against a zero-length interval, e.g. with a coarse monotonic clock.
    tokens_per_second = num_tokens / latency if latency > 0 else 0.0
    return client.request_id, latency, first_token_latency, tokens_per_second


def main(
    concurrencies: list[int] = [10],
    result_csv: str = "load_test_results.csv",
    ftl_json: str = "ftl_dist.json",
):
    """Run the load test at each concurrency level and write summary files.

    For each entry in ``concurrencies``, a ``multiprocessing.Pool`` with that
    many worker processes runs ``request`` on every prompt in ``PROMPTS``. It
    collects per-request latency, first-token latency (FTL) and tokens per
    second, and prints summary statistics after each level.

    Args:
        concurrencies: Worker-process counts to test, in the given order. The
            default list is only read, never mutated, so sharing it is safe.
        result_csv: Output CSV path. The file gets one header row plus one
            summary row per concurrency level.
        ftl_json: Output JSON path. It maps each concurrency level to its list
            of valid first-token latencies; JSON turns the int keys into strings.
    """
    data = []
    ftl_dist = {}

    for concurrency in concurrencies:
        latencies = []
        first_token_latencies = []
        tps = []

        start_time = time.monotonic()
        with mp.Pool(processes=concurrency) as pool:
            # Results arrive in completion order, not submission order.
            for request_id, latency, first_token_latency, tokens_per_second in pool.imap_unordered(request, PROMPTS):
                latencies.append(latency)
                first_token_latencies.append(first_token_latency)
                tps.append(tokens_per_second)
                print(f"Request ID {request_id} finished, {latency=:.2f}s, {first_token_latency=:.2f}s, {tokens_per_second=:.2f} tokens/s")
        total_time = time.monotonic() - start_time

        # request() reports an FTL of -1.0 when nothing was streamed. Drop that
        # sentinel so it does not skew the FTL mean, deciles or distribution.
        valid_ftls = [ftl for ftl in first_token_latencies if ftl >= 0.0]

        average_latency = sum(latencies) / len(latencies)
        average_first_token_latency = (
            sum(valid_ftls) / len(valid_ftls) if valid_ftls else float("nan")
        )
        # n=10 gives the nine decile cut points (not quartiles). Before Python
        # 3.13, quantiles() raises StatisticsError on fewer than two points.
        first_token_latency_deciles = (
            quantiles(valid_ftls, n=10) if len(valid_ftls) >= 2 else []
        )
        ftl_dist[concurrency] = valid_ftls
        average_tokens_per_second = sum(tps) / len(tps)
        requests_per_second = len(latencies) / total_time
        print(f"Total time: {total_time:.2f}s")
        print(f"Average latency: {average_latency:.2f}s")
        print(f"Average first token latency: {average_first_token_latency:.2f}s")
        print(f"Average tokens per second: {average_tokens_per_second:.2f}")
        print(f"Requests per second: {requests_per_second:.2f}")
        print(f"First token latency deciles: {first_token_latency_deciles}")
        data.append((
            concurrency,
            total_time,
            average_latency,
            average_first_token_latency,
            average_tokens_per_second,
            requests_per_second,
        ))

    # The csv module requires newline="" so that it controls line endings;
    # without it, rows are followed by blank lines on Windows.
    with open(result_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow((
            "concurrency",
            "total_time",
            "average_latency",
            "average_first_token_latency",
            "average_tokens_per_second",
            "requests_per_second",
        ))
        writer.writerows(data)

    with open(ftl_json, "w") as f:
        json.dump(ftl_dist, f)


if __name__ == "__main__":
    # tyro builds the command-line interface from main's signature and defaults.
    # The guard also keeps multiprocessing workers from re-running the CLI when
    # they re-import this module (spawn/forkserver start methods).
    tyro.cli(main)