# llm-arch / pages / 030_Test_Runner.py
# alfraser's picture
# Updated from random.choices to random.sample everywhere a distinct random set is needed: choices samples with replacement, so the same item can be picked twice. Discovered in pricing testing.
# b897a48
# raw
# history blame
# No virus
# 5.42 kB
import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample
from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup
# Componentise different test options
def display_custom_test():
    """
    Render the "custom test" panel.

    Collects an optional comment, the architectures to include, the number of
    questions and the number of worker threads, shows the generated group tag
    and the total number of tests, and - once the Run button is pressed -
    passes randomly selected questions to batch_test.
    """
    st.write("## Run a new custom test")

    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    arch_names = [a.name for a in Architecture.architectures]
    selected_archs = st.multiselect(label="Architectures", options=arch_names, key="custom_test_archs")

    st.write("### Number of questions to ask:")
    max_questions = TestGenerator.question_count()
    q_count = st.slider(label="Number of questions", min_value=1, max_value=max_questions, step=1, key="custom_q_count")

    st.write("### Number of threads to use for testsing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="custom_thread_slider")

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    # Every selected architecture is asked every question
    total_tests = len(selected_archs) * q_count

    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")

    # The button is disabled while there is nothing to run
    nothing_to_run = (total_tests == 0)
    run_pressed = st.button("**Run**", disabled=nothing_to_run, key="custom_test_button")
    if not run_pressed:
        return

    with st.spinner():
        questions = TestGenerator.get_random_questions(q_count)
        batch_test(
            questions=questions,
            architectures=selected_archs,
            trace_tags=[tag, "TestRunner"],
            trace_comment=comment,
            num_workers=thread_count,
        )
def display_pricing_fact_test():
    """
    Render the "pricing fact test" panel and score the selected architectures.

    A "How much is the ...?" question is built for every product in
    Product.all. When the Run button is pressed, a distinct random sample of
    those questions is sent to each selected architecture through batch_test.
    The first dollar amount found in each response is compared with the
    product's price, and a per-architecture table of correct / incorrect
    counts is displayed.
    """
    # Returned when a response contains no recognisable price. It is negative,
    # so it is not expected to equal a real product price and the answer is
    # counted as incorrect.
    no_price_found = -0.1

    # A dollar sign, digits with optional thousands separators, then exactly
    # two decimals (e.g. "$1,299.99"). Written as a raw string so the
    # backslashes reach the regex engine rather than being treated as
    # (invalid) Python string escapes.
    price_pattern = re.compile(r'\$[,\d]+\.\d\d')

    def get_question_price_pairs():
        """Return a list of (question, price) tuples, one per product in Product.all."""
        DataLoader.load_data()
        pairs = []
        for p in Product.all.values():
            category_name = p.category.lower_singular_name
            if category_name == "tv":
                # Keep the acronym upper-case in the question text
                category_name = "TV"
            question = f'How much is the {p.name} {category_name}?'
            pairs.append((question, p.price))
        return pairs

    def get_price_from_response(response: str) -> float:
        """Return the first dollar amount in the response as a float, or the sentinel if there is none."""
        match = price_pattern.search(response)
        if match is None:
            return no_price_found
        # Drop the leading "$" and any thousands separators before converting
        return float(match.group()[1:].replace(',', ''))

    st.write("## Run a pricing fact test")
    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(label="Architectures", options=[a.name for a in Architecture.architectures], key="pricing_test_archs")

    question_price_pairs = get_question_price_pairs()

    st.write("### Number of questions to ask:")
    q_count = st.slider(label="Number of questions", min_value=1, max_value=len(question_price_pairs), step=1, key="pricing_q_count")

    st.write("### Number of threads to use for testsing:")
    thread_count = st.slider(label="Number of threads", min_value=1, max_value=64, step=1, value=16, key="pricing_thread_slider")

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    total_tests = len(selected_archs) * q_count

    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        # sample() picks without replacement, so no pair is chosen twice
        question_price_dict = dict(sample(question_price_pairs, k=q_count))
        questions = list(question_price_dict)

        # Per architecture: [correct, incorrect]
        answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}

        with st.spinner():
            # Each result is unpacked below as (architecture name, question, response)
            results = batch_test(questions=questions, architectures=selected_archs,
                                 trace_tags=[tag, "TestRunner"], trace_comment=comment,
                                 num_workers=thread_count)

        for arch, query, response in results:
            target_price = question_price_dict[query]
            answer_price = get_price_from_response(response)
            if target_price == answer_price:
                answer_stats[arch][0] += 1
            else:
                answer_stats[arch][1] += 1

        table_data = []
        for arch_name in selected_archs:
            correct, incorrect = answer_stats[arch_name]
            total = correct + incorrect
            # An architecture with no results at all would otherwise divide by zero
            percent_correct = round(correct / total * 100, 1) if total else 0.0
            table_data.append([arch_name, correct, incorrect, total, f'{percent_correct:.1f}%'])

        df = DataFrame(table_data, columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'])
        # Replace the numeric index with a blank one so st.table does not show row numbers
        st.table(df.assign(no_index='').set_index('no_index'))
# The test panels list Architecture.architectures, so load them first if they are not set yet
if Architecture.architectures is None:
    Architecture.load_architectures()

# Only draw the page body when st_setup returns a truthy value
if st_setup('LLM Arch'):
    st.write("# Test Runner")
    # One collapsible section per test type, rendered in this order
    page_sections = (
        ("Pricing Fact Tests", display_pricing_fact_test),
        ("Custom Tests", display_custom_test),
    )
    for section_title, render_section in page_sections:
        with st.expander(section_title):
            render_section()