alfraser commited on
Commit
bb7db2c
·
1 Parent(s): e999f4f

Modified test runner to dispatch requests in parallel to make use of the fact that there is a lot of wait time for the LLM. Defaulting to 16 threads.

Browse files
Files changed (2) hide show
  1. pages/030_Test_Runner.py +21 -30
  2. src/testing.py +58 -2
pages/030_Test_Runner.py CHANGED
@@ -6,7 +6,7 @@ from random import choices
6
  from src.architectures import *
7
  from src.common import generate_group_tag
8
  from src.datatypes import *
9
- from src.testing import TestGenerator
10
  from src.st_helpers import st_setup
11
 
12
 
@@ -30,18 +30,10 @@ def display_custom_test():
30
  st.write("### Run:")
31
  st.write(f"**{total_tests}** total tests will be run")
32
  if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
33
- progress = st.progress(0.0, text="Running tests...")
34
- questions = TestGenerator.get_random_questions(q_count)
35
- num_complete = 0
36
- for arch_name in selected_archs:
37
- architecture = Architecture.get_architecture(arch_name)
38
- for q in questions:
39
- architecture(ArchitectureRequest(q), trace_tags=[tag, "TestRunner"], trace_comment=comment)
40
- num_complete += 1
41
- if num_complete == total_tests:
42
- progress.empty()
43
- else:
44
- progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
45
 
46
 
47
  def display_pricing_fact_test():
@@ -83,25 +75,24 @@ def display_pricing_fact_test():
83
  st.write("### Run:")
84
  st.write(f"**{total_tests}** total tests will be run")
85
  if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
86
- progress = st.progress(0.0, text="Running tests...")
87
- questions = choices(question_price_pairs, k=q_count)
88
- num_complete = 0
89
  answer_stats = {}
90
  for arch_name in selected_archs:
91
- answer_stats[arch_name] = [0, 0] # [Correct, Incorrect] only used locally here
92
- architecture = Architecture.get_architecture(arch_name)
93
- for question, price in questions:
94
- request = ArchitectureRequest(question)
95
- architecture(request, trace_tags=[tag, "TestRunner"], trace_comment=comment)
96
- if price == get_price_from_response(request.response):
97
- answer_stats[arch_name][0] += 1
 
 
 
98
  else:
99
- answer_stats[arch_name][1] += 1
100
- num_complete += 1
101
- if num_complete == total_tests:
102
- progress.empty()
103
- else:
104
- progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
105
  table_data = []
106
  for arch_name in selected_archs:
107
  correct = answer_stats[arch_name][0]
@@ -110,7 +101,7 @@ def display_pricing_fact_test():
110
  percent_correct = round(correct / total * 100, 1)
111
  table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
112
  df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
113
- st.table(df.assign(hack='').set_index('hack'))
114
 
115
 
116
  if Architecture.architectures is None:
 
6
  from src.architectures import *
7
  from src.common import generate_group_tag
8
  from src.datatypes import *
9
+ from src.testing import TestGenerator, batch_test
10
  from src.st_helpers import st_setup
11
 
12
 
 
30
  st.write("### Run:")
31
  st.write(f"**{total_tests}** total tests will be run")
32
  if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
33
+ with st.spinner():
34
+ questions = TestGenerator.get_random_questions(q_count)
35
+ batch_test(questions=questions, architectures=selected_archs,
36
+ trace_tags=[tag, "TestRunner"], trace_comment=comment)
 
 
 
 
 
 
 
 
37
 
38
 
39
  def display_pricing_fact_test():
 
75
  st.write("### Run:")
76
  st.write(f"**{total_tests}** total tests will be run")
77
  if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
78
+ question_price_pairs = choices(question_price_pairs, k=q_count)
79
+ question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
80
+ questions = list(question_price_dict.keys())
81
  answer_stats = {}
82
  for arch_name in selected_archs:
83
+ answer_stats[arch_name] = [0, 0] # [correct, incorrect]
84
+
85
+ with st.spinner():
86
+ results: List[Tuple[str, str, str]] = batch_test(questions=questions, architectures=selected_archs,
87
+ trace_tags=[tag, "TestRunner"], trace_comment=comment)
88
+ for arch, query, response in results:
89
+ target_price = question_price_dict[query]
90
+ answer_price = get_price_from_response(response)
91
+ if target_price == answer_price:
92
+ answer_stats[arch][0] += 1
93
  else:
94
+ answer_stats[arch][1] += 1
95
+
 
 
 
 
96
  table_data = []
97
  for arch_name in selected_archs:
98
  correct = answer_stats[arch_name][0]
 
101
  percent_correct = round(correct / total * 100, 1)
102
  table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
103
  df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
104
+ st.table(df.assign(no_index='').set_index('no_index'))
105
 
106
 
107
  if Architecture.architectures is None:
src/testing.py CHANGED
@@ -7,13 +7,69 @@ import sqlite3
7
  import sys
8
 
9
  from huggingface_hub import Repository
 
10
  from random import choices
11
- from typing import List, Dict, Optional
 
12
 
13
- from src.architectures import Architecture
14
  from src.common import data_dir
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class TestGenerator:
18
  """
19
  Wrapper class to hold testing questions and serve up examples
 
7
  import sys
8
 
9
  from huggingface_hub import Repository
10
+ from queue import Queue
11
  from random import choices
12
+ from threading import Thread
13
+ from typing import Dict, List, Optional, Tuple
14
 
15
+ from src.architectures import Architecture, ArchitectureRequest
16
  from src.common import data_dir
17
 
18
 
19
class ArchitectureTestWorker(Thread):
    """
    Daemon worker thread that pulls (architecture_name, request) pairs off a shared
    work queue and runs each request through the named architecture.

    A (None, None) item is the shutdown sentinel: the worker marks it done and exits.
    """

    def __init__(self, work_queue: Queue, worker_name: str, trace_tags: List[str], trace_comment: str):
        Thread.__init__(self)
        self.work_queue = work_queue  # shared queue of (arch_name, ArchitectureRequest) items
        self.worker_name = worker_name  # label used in progress output only
        self.trace_tags = trace_tags  # tags recorded against each architecture trace
        self.trace_comment = trace_comment  # comment recorded against each architecture trace

    def run(self):
        """Consume work items until the (None, None) sentinel is received."""
        running: bool = True
        while running:
            arch, request = self.work_queue.get()
            try:
                if arch is None:
                    running = False
                else:
                    print(f'{self.worker_name} running "{request.request}" through {arch}')
                    architecture = Architecture.get_architecture(arch)
                    architecture(request, trace_tags=self.trace_tags, trace_comment=self.trace_comment)
            except Exception as err:
                # A failed request must not kill the worker: if the thread died here,
                # the remaining queued items would never be marked done and the
                # dispatcher's queue.join() would hang forever. Log and keep going.
                print(f'{self.worker_name} error: {err}')
            finally:
                # Always account for the item so queue.join() can complete.
                self.work_queue.task_done()
40
+
41
+
42
def batch_test(questions: List[str], architectures: List[str], trace_comment: str = "",
               trace_tags: Optional[List[str]] = None, num_workers: int = 16) -> List[Tuple[str, str, str]]:
    """
    Creates a worker pool and dispatches every question to every architecture in
    parallel, returning the answer per (architecture, question) pair.

    :param questions: A list of the questions
    :param architectures: A list of the names of the architectures
    :param trace_comment: Comment recorded against each architecture trace
    :param trace_tags: Tags recorded against each architecture trace
    :param num_workers: The number of worker threads to run (default 16)
    :return: A list of Tuples of (arch_name, question, answer)
    """
    # Avoid the shared-mutable-default pitfall for the tags list
    trace_tags = [] if trace_tags is None else trace_tags

    queue = Queue()

    # NOTE(review): duplicate questions collapse onto one (arch, question) key here,
    # so only the last duplicate's response is returned, although every duplicate is
    # still dispatched and traced — callers should pass de-duplicated questions.
    question_record: Dict[Tuple[str, str], ArchitectureRequest] = {}
    for q in questions:
        for a in architectures:
            request = ArchitectureRequest(q)
            question_record[(a, q)] = request
            queue.put((a, request))

    for i in range(num_workers):
        worker = ArchitectureTestWorker(work_queue=queue, worker_name=f'Worker {i+1}',
                                        trace_tags=trace_tags, trace_comment=trace_comment)
        worker.daemon = True
        worker.start()
        queue.put((None, None))  # One shutdown sentinel per worker (FIFO: after all real work)

    # Block until every request (and every sentinel) has been marked done
    queue.join()

    # Repackage and return just the list of (arch_name, question, answer)
    return [(k[0], k[1], v.response) for k, v in question_record.items()]
71
+
72
+
73
  class TestGenerator:
74
  """
75
  Wrapper class to hold testing questions and serve up examples