alfraser committed on
Commit
c319c31
1 Parent(s): 963fb4a

Added a runner for pricing fact checks to assess the level of fact embedding in the latest model

Files changed (2)
  1. pages/030_Test_Runner.py +114 -33
  2. src/datatypes.py +6 -0
pages/030_Test_Runner.py CHANGED
@@ -1,45 +1,126 @@
 
+ import regex as re
  import streamlit as st

+ from pandas import DataFrame
+ from random import choices
  from src.architectures import *
  from src.common import generate_group_tag
+ from src.datatypes import *
  from src.testing import TestGenerator
  from src.st_helpers import st_setup


+ # Componentise different test options
+ def display_custom_test():
+     st.write("## Run a new custom test")
+     st.write("### Comment:")
+     comment = st.text_input("Optional comment for the test", key="custom_test_comment")
+
+     st.write("### Architectures to include:")
+     selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="custom_test_archs")
+
+     st.write("### Number of questions to ask:")
+     q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1)
+
+     st.write("### Tag:")
+     tag = generate_group_tag()
+     st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
+
+     total_tests = len(selected_archs) * q_count
+     st.write("### Run:")
+     st.write(f"**{total_tests}** total tests will be run")
+     if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
+         progress = st.progress(0.0, text="Running tests...")
+         questions = TestGenerator.get_random_questions(q_count)
+         num_complete = 0
+         for arch_name in selected_archs:
+             architecture = Architecture.get_architecture(arch_name)
+             for q in questions:
+                 architecture(ArchitectureRequest(q), trace_tags=[tag, "TestRunner"], trace_comment=comment)
+                 num_complete += 1
+                 if num_complete == total_tests:
+                     progress.empty()
+                 else:
+                     progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
+
+
+ def display_pricing_fact_test():
+     def get_question_price_pairs():
+         DataLoader.load_data()
+         pairs = []
+         for p in Product.all.values():
+             price = p.price
+             product_name = p.name
+             category_name = p.category.lower_singular_name
+             if category_name == "tv":
+                 category_name = "TV"
+             question = f'How much is the {product_name} {category_name}?'
+             pairs.append((question, price))
+         return pairs
+
+     def get_price_from_response(response: str) -> float:
+         prices = re.findall('\$[,\d]+\.\d\d', response)
+         if len(prices) == 0:
+             print(f"Found no price in response '{response}'")
+             return -0.1
+         return float(prices[0][1:].replace(',',''))
+
+     st.write("## Run a pricing fact test")
+     st.write("### Comment:")
+     comment = st.text_input("Optional comment for the test", key="pricing_test_comment")
+
+     st.write("### Architectures to include:")
+     selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")
+
+     question_price_pairs = get_question_price_pairs()
+     st.write("### Number of questions to ask:")
+     q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1)
+
+     st.write("### Tag:")
+     tag = generate_group_tag()
+     st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
+
+     total_tests = len(selected_archs) * q_count
+     st.write("### Run:")
+     st.write(f"**{total_tests}** total tests will be run")
+     if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
+         progress = st.progress(0.0, text="Running tests...")
+         questions = choices(question_price_pairs, k=q_count)
+         num_complete = 0
+         answer_stats = {}
+         for arch_name in selected_archs:
+             answer_stats[arch_name] = [0, 0]  # [Correct, Incorrect] only used locally here
+             architecture = Architecture.get_architecture(arch_name)
+             for question, price in questions:
+                 request = ArchitectureRequest(question)
+                 architecture(request, trace_tags=[tag, "TestRunner"], trace_comment=comment)
+                 if price == get_price_from_response(request.response):
+                     answer_stats[arch_name][0] += 1
+                 else:
+                     answer_stats[arch_name][1] += 1
+                 num_complete += 1
+                 if num_complete == total_tests:
+                     progress.empty()
+                 else:
+                     progress.progress(num_complete / total_tests, f"Run {num_complete} of {total_tests} tests...")
+         table_data = []
+         for arch_name in selected_archs:
+             correct = answer_stats[arch_name][0]
+             incorrect = answer_stats[arch_name][1]
+             total = correct + incorrect
+             percent_correct = round(correct / total * 100, 1)
+             table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])
+         df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
+         st.table(df.assign(hack='').set_index('hack'))
+
+
  if Architecture.architectures is None:
      Architecture.load_architectures()

  if st_setup('LLM Arch'):
-     summary = st.container()
-     with summary:
-         st.write("# Test Runner")
-         st.write("## Run a new test")
-         st.write("### Comment:")
-         comment = st.text_input("Optional comment for the test")
-
-         st.write("### Architectures to include:")
-         selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures])
-
-         st.write("### Number of questions to ask:")
-         q_count = st.slider(label="Number of questions", min_value=1, max_value=TestGenerator.question_count(), step=1)
-
-         st.write("### Tag:")
-         tag = generate_group_tag()
-         st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')
-
-         total_tests = len(selected_archs) * q_count
-         st.write("### Run:")
-         st.write(f"**{total_tests}** total tests will be run")
-         if st.button("**Run**", disabled=(total_tests==0)):
-             progress = st.progress(0.0, text="Running tests...")
-             questions = TestGenerator.get_random_questions(q_count)
-             num_complete = 0
-             for arch_name in selected_archs:
-                 architecture = Architecture.get_architecture(arch_name)
-                 for q in questions:
-                     architecture(ArchitectureRequest(q), trace_tags=[tag, "TestRunner"], trace_comment=comment)
-                     num_complete += 1
-                     if num_complete == total_tests:
-                         progress.empty()
-                     else:
-                         progress.progress(num_complete/total_tests, f"Run {num_complete} of {total_tests} tests...")
+     st.write("# Test Runner")
+     with st.expander("Pricing Fact Tests"):
+         display_pricing_fact_test()
+     with st.expander("Custom Tests"):
+         display_custom_test()
+
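For orientation, the scoring in display_pricing_fact_test hinges on get_price_from_response: it pulls the first dollar amount of the form $1,234.56 out of the architecture's response and returns the sentinel -0.1 when no such amount is present, so a missing or unparseable price is always scored as incorrect. Below is a minimal standalone sketch of that behaviour; it uses the standard-library re module rather than the regex package the page imports, writes the pattern as a raw string, and the sample responses are invented for illustration.

import re


def get_price_from_response(response: str) -> float:
    # First dollar amount like $1,299.99; -0.1 signals "no price found"
    prices = re.findall(r'\$[,\d]+\.\d\d', response)
    if len(prices) == 0:
        return -0.1
    return float(prices[0][1:].replace(',', ''))


print(get_price_from_response("The Quantum 55 TV is $1,299.99 today."))  # 1299.99
print(get_price_from_response("It retails for about $500."))             # -0.1 (no cents, so no match)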
src/datatypes.py CHANGED
@@ -106,6 +106,12 @@ class Category:
          return self.name[:-1] # Clip the s
      return self.name

+     @property
+     def lower_singular_name(self):
+         if self.name[-1] == "s":
+             return self.name[:-1].lower() # Clip the s
+         return self.name.lower()
+

  class Feature:
      all = {}
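For context on how the new property is used: the pricing fact test builds questions like "How much is the {product} {category}?" from Category.lower_singular_name, which lowercases the category name and clips a trailing "s", with the test runner special-casing "tv" back to "TV". A rough sketch of that naming behaviour with a hypothetical stand-in class (the real Category in src/datatypes.py carries more fields and a class-level registry):

class Category:
    """Hypothetical minimal stand-in; only the naming logic from the diff is shown."""

    def __init__(self, name: str):
        self.name = name

    @property
    def lower_singular_name(self):
        if self.name[-1] == "s":
            return self.name[:-1].lower()  # Clip the s
        return self.name.lower()


print(Category("Laptops").lower_singular_name)  # laptop
print(Category("TVs").lower_singular_name)      # tv (the test runner re-capitalises this to "TV")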