Spaces:
Runtime error
Runtime error
Modified test runner to dispatch requests in parallel to make use of the fact that there is a lot of wait time for the LLM. Defaulting to 16 threads.
Browse files- pages/030_Test_Runner.py +21 -30
- src/testing.py +58 -2
pages/030_Test_Runner.py
CHANGED
@@ -6,7 +6,7 @@ from random import choices
|
|
6 |
from src.architectures import *
|
7 |
from src.common import generate_group_tag
|
8 |
from src.datatypes import *
|
9 |
-
from src.testing import TestGenerator
|
10 |
from src.st_helpers import st_setup
|
11 |
|
12 |
|
@@ -30,18 +30,10 @@ def display_custom_test():
|
|
30 |
st.write("### Run:")
|
31 |
st.write(f"**{total_tests}** total tests will be run")
|
32 |
if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
architecture = Architecture.get_architecture(arch_name)
|
38 |
-
for q in questions:
|
39 |
-
architecture(ArchitectureRequest(q), trace_tags=[tag, "TestRunner"], trace_comment=comment)
|
40 |
-
num_complete += 1
|
41 |
-
if num_complete == total_tests:
|
42 |
-
progress.empty()
|
43 |
-
else:
|
44 |
-
progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
|
45 |
|
46 |
|
47 |
def display_pricing_fact_test():
|
@@ -83,25 +75,24 @@ def display_pricing_fact_test():
|
|
83 |
st.write("### Run:")
|
84 |
st.write(f"**{total_tests}** total tests will be run")
|
85 |
if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
answer_stats = {}
|
90 |
for arch_name in selected_archs:
|
91 |
-
answer_stats[arch_name] = [0, 0] # [
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
98 |
else:
|
99 |
-
answer_stats[
|
100 |
-
|
101 |
-
if num_complete == total_tests:
|
102 |
-
progress.empty()
|
103 |
-
else:
|
104 |
-
progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
|
105 |
table_data = []
|
106 |
for arch_name in selected_archs:
|
107 |
correct = answer_stats[arch_name][0]
|
@@ -110,7 +101,7 @@ def display_pricing_fact_test():
|
|
110 |
percent_correct = round(correct / total * 100, 1)
|
111 |
table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
|
112 |
df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
|
113 |
-
st.table(df.assign(
|
114 |
|
115 |
|
116 |
if Architecture.architectures is None:
|
|
|
6 |
from src.architectures import *
|
7 |
from src.common import generate_group_tag
|
8 |
from src.datatypes import *
|
9 |
+
from src.testing import TestGenerator, batch_test
|
10 |
from src.st_helpers import st_setup
|
11 |
|
12 |
|
|
|
30 |
st.write("### Run:")
|
31 |
st.write(f"**{total_tests}** total tests will be run")
|
32 |
if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
|
33 |
+
with st.spinner():
|
34 |
+
questions = TestGenerator.get_random_questions(q_count)
|
35 |
+
batch_test(questions=questions, architectures=selected_archs,
|
36 |
+
trace_tags=[tag, "TestRunner"], trace_comment=comment)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
def display_pricing_fact_test():
|
|
|
75 |
st.write("### Run:")
|
76 |
st.write(f"**{total_tests}** total tests will be run")
|
77 |
if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
|
78 |
+
question_price_pairs = choices(question_price_pairs, k=q_count)
|
79 |
+
question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
|
80 |
+
questions = list(question_price_dict.keys())
|
81 |
answer_stats = {}
|
82 |
for arch_name in selected_archs:
|
83 |
+
answer_stats[arch_name] = [0, 0] # [correct, incorrect]
|
84 |
+
|
85 |
+
with st.spinner():
|
86 |
+
results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
|
87 |
+
trace_tags=[tag, "TestRunner"], trace_comment=comment)
|
88 |
+
for arch, query, response in results:
|
89 |
+
target_price = question_price_dict[query]
|
90 |
+
answer_price = get_price_from_response(response)
|
91 |
+
if target_price == answer_price:
|
92 |
+
answer_stats[arch][0] += 1
|
93 |
else:
|
94 |
+
answer_stats[arch][1] += 1
|
95 |
+
|
|
|
|
|
|
|
|
|
96 |
table_data = []
|
97 |
for arch_name in selected_archs:
|
98 |
correct = answer_stats[arch_name][0]
|
|
|
101 |
percent_correct = round(correct / total * 100, 1)
|
102 |
table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
|
103 |
df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
|
104 |
+
st.table(df.assign(no_index='').set_index('no_index'))
|
105 |
|
106 |
|
107 |
if Architecture.architectures is None:
|
src/testing.py
CHANGED
@@ -7,13 +7,69 @@ import sqlite3
|
|
7 |
import sys
|
8 |
|
9 |
from huggingface_hub import Repository
|
|
|
10 |
from random import choices
|
11 |
-
from
|
|
|
12 |
|
13 |
-
from src.architectures import Architecture
|
14 |
from src.common import data_dir
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
class TestGenerator:
|
18 |
"""
|
19 |
Wrapper class to hold testing questions and serve up examples
|
|
|
7 |
import sys
|
8 |
|
9 |
from huggingface_hub import Repository
|
10 |
+
from queue import Queue
|
11 |
from random import choices
|
12 |
+
from threading import Thread
|
13 |
+
from typing import Dict, List, Optional, Tuple
|
14 |
|
15 |
+
from src.architectures import Architecture, ArchitectureRequest
|
16 |
from src.common import data_dir
|
17 |
|
18 |
|
19 |
+
class ArchitectureTestWorker(Thread):
    """
    Worker thread that drains a shared queue of (architecture_name, request)
    pairs, running each request through the named architecture. A
    (None, None) item is the shutdown sentinel: the worker marks it done
    and exits its loop.
    """

    def __init__(self, work_queue: Queue, worker_name: str, trace_tags: List[str], trace_comment: str):
        super().__init__()
        self.work_queue = work_queue        # shared queue of (arch_name, request) work items
        self.worker_name = worker_name      # label used in progress logging
        self.trace_tags = trace_tags        # tags forwarded to every architecture call
        self.trace_comment = trace_comment  # comment forwarded to every architecture call

    def run(self):
        """Consume work items until the shutdown sentinel arrives."""
        while True:
            arch, request = self.work_queue.get()
            try:
                if arch is None:
                    # Shutdown sentinel: leave the loop (task_done still runs in finally).
                    break
                print(f'{self.worker_name} running "{request.request}" through {arch}')
                resolved = Architecture.get_architecture(arch)
                resolved(request, trace_tags=self.trace_tags, trace_comment=self.trace_comment)
            finally:
                # Always account for the item so queue.join() can complete,
                # even if the architecture call raises.
                self.work_queue.task_done()
|
40 |
+
|
41 |
+
|
42 |
+
def batch_test(questions: List[str], architectures: List[str], trace_comment: str = "",
               trace_tags: Optional[List[str]] = None, num_workers: int = 16) -> List[Tuple[str, str, str]]:
    """
    Create a worker pool and dispatch every question to every architecture in
    parallel, returning the answer per (architecture, question) pair.

    :param questions: The questions to run
    :param architectures: The names of the architectures to run them through
    :param trace_comment: Comment attached to the trace of every request
    :param trace_tags: Tags attached to the trace of every request
    :param num_workers: The number of worker threads to start
    :return: A list of tuples of (arch_name, question, answer)
    """
    # Avoid the shared-mutable-default pitfall: build a fresh list per call.
    if trace_tags is None:
        trace_tags = []

    work_queue: Queue = Queue()

    # Keep a handle on every request so the responses can be collected after join().
    # NOTE(review): duplicate questions for the same architecture collapse onto a
    # single key here, so the result can hold fewer entries than
    # len(questions) * len(architectures) — confirm callers de-duplicate if needed.
    question_record: Dict[Tuple[str, str], ArchitectureRequest] = {}
    for q in questions:
        for a in architectures:
            request = ArchitectureRequest(q)
            question_record[(a, q)] = request
            work_queue.put((a, request))

    for i in range(num_workers):
        worker = ArchitectureTestWorker(work_queue=work_queue, worker_name=f'Worker {i+1}',
                                        trace_tags=trace_tags, trace_comment=trace_comment)
        worker.daemon = True  # don't block interpreter exit on a stuck worker
        worker.start()
        work_queue.put((None, None))  # one shutdown sentinel per worker

    # Block until every queued item (requests and sentinels) is marked done.
    work_queue.join()

    # Repackage and return just the list of (arch_name, question, answer)
    return [(k[0], k[1], v.response) for k, v in question_record.items()]
|
71 |
+
|
72 |
+
|
73 |
class TestGenerator:
|
74 |
"""
|
75 |
Wrapper class to hold testing questions and serve up examples
|