import regex as re
import streamlit as st
from pandas import DataFrame
from random import sample

from src.architectures import *
from src.common import generate_group_tag
from src.datatypes import *
from src.testing import TestGenerator, batch_test
from src.st_helpers import st_setup


# Componentise different test options
def display_custom_test():
    """Render the "custom test" form and run the batch when requested.

    The form collects an optional comment, the architectures to include,
    the number of random questions and the worker thread count. Pressing
    "Run" draws that many questions from ``TestGenerator`` and passes them
    to ``batch_test`` together with the generated group tag.
    """
    st.write("## Run a new custom test")

    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="custom_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(
        label="Architectures",
        options=[a.name for a in Architecture.architectures],
        key="custom_test_archs",
    )

    st.write("### Number of questions to ask:")
    q_count = st.slider(
        label="Number of questions",
        min_value=1,
        max_value=TestGenerator.question_count(),
        step=1,
        key="custom_q_count",
    )

    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(
        label="Number of threads",
        min_value=1,
        max_value=64,
        step=1,
        value=16,
        key="custom_thread_slider",
    )

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    # One test is run per (architecture, question) combination.
    total_tests = len(selected_archs) * q_count

    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    # The button stays disabled until at least one architecture is selected.
    if st.button("**Run**", disabled=(total_tests == 0), key="custom_test_button"):
        with st.spinner():
            questions = TestGenerator.get_random_questions(q_count)
            batch_test(
                questions=questions,
                architectures=selected_archs,
                trace_tags=[tag, "TestRunner"],
                trace_comment=comment,
                num_workers=thread_count,
            )


def display_pricing_fact_test():
    """Render the "pricing fact test" form, run it and tabulate the results.

    One "How much is the ...?" question is built per product in
    ``Product.all``. When "Run" is pressed, a random sample of those
    questions is sent to every selected architecture through
    ``batch_test``. The first dollar amount found in each response is
    compared with the product's price, and a per-architecture table of
    correct/incorrect counts is displayed.
    """

    def get_question_price_pairs():
        """Return a list of ``(question, price)`` tuples, one per product."""
        DataLoader.load_data()
        pairs = []
        for p in Product.all.values():
            price = p.price
            product_name = p.name
            category_name = p.category.lower_singular_name
            # "tv" is the one category name that is restored to upper case.
            if category_name == "tv":
                category_name = "TV"
            question = f'How much is the {product_name} {category_name}?'
            pairs.append((question, price))
        return pairs

    def get_price_from_response(response: str) -> float:
        """Return the first ``$1,234.56``-style amount in *response*.

        Returns ``-0.1`` when the response contains no such amount.
        """
        # Sentinel for "no price found"; value kept from the original code.
        # NOTE(review): assumes no product price equals -0.1 - confirm.
        no_price_found = -0.1
        # Raw string: '\$', '\d' and '\.' are invalid escapes in a normal
        # string literal and trigger a SyntaxWarning on Python 3.12+.
        prices = re.findall(r'\$[,\d]+\.\d\d', response)
        if not prices:
            return no_price_found
        # Drop the leading '$' and the thousands separators before parsing.
        return float(prices[0][1:].replace(',', ''))

    st.write("## Run a pricing fact test")

    st.write("### Comment:")
    comment = st.text_input("Optional comment for the test", key="pricing_test_comment")

    st.write("### Architectures to include:")
    selected_archs = st.multiselect(
        label="Architectures",
        options=[a.name for a in Architecture.architectures],
        key="pricing_test_archs",
    )

    question_price_pairs = get_question_price_pairs()

    st.write("### Number of questions to ask:")
    q_count = st.slider(
        label="Number of questions",
        min_value=1,
        max_value=len(question_price_pairs),
        step=1,
        key="pricing_q_count",
    )

    st.write("### Number of threads to use for testing:")
    thread_count = st.slider(
        label="Number of threads",
        min_value=1,
        max_value=64,
        step=1,
        value=16,
        key="pricing_thread_slider",
    )

    st.write("### Tag:")
    tag = generate_group_tag()
    st.write(f'Test will be tagged as "{tag}" - record this for easy searching later')

    # One test is run per (architecture, question) combination.
    total_tests = len(selected_archs) * q_count

    st.write("### Run:")
    st.write(f"**{total_tests}** total tests will be run")
    # The button stays disabled until at least one architecture is selected.
    if st.button("**Run**", disabled=(total_tests == 0), key="pricing_test_button"):
        question_price_pairs = sample(question_price_pairs, k=q_count)
        question_price_dict = {qpp[0]: qpp[1] for qpp in question_price_pairs}
        questions = list(question_price_dict.keys())

        # Per-architecture tally: [correct, incorrect]
        answer_stats = {arch_name: [0, 0] for arch_name in selected_archs}

        with st.spinner():
            results: list[tuple[str, str, str]] = batch_test(
                questions=questions,
                architectures=selected_archs,
                trace_tags=[tag, "TestRunner"],
                trace_comment=comment,
                num_workers=thread_count,
            )

        for arch, query, response in results:
            target_price = question_price_dict[query]
            answer_price = get_price_from_response(response)
            # NOTE(review): exact equality between the catalogue price and
            # the parsed float - confirm the stored price type makes this safe.
            if target_price == answer_price:
                answer_stats[arch][0] += 1
            else:
                answer_stats[arch][1] += 1

        table_data = []
        for arch_name in selected_archs:
            correct, incorrect = answer_stats[arch_name]
            total = correct + incorrect
            # Guard against an architecture for which no result came back;
            # the unguarded division raised ZeroDivisionError in that case.
            percent_correct = round(correct / total * 100, 1) if total else 0.0
            table_data.append(
                [arch_name, correct, incorrect, total, f'{percent_correct:.1f}%']
            )

        df = DataFrame(
            table_data,
            columns=['Architecture', 'Correct', 'Incorrect', 'Total', '% Correct'],
        )
        # Replace the numeric index with a blank column so st.table shows
        # no row numbers.
        st.table(df.assign(no_index='').set_index('no_index'))


# Load the architectures once if nothing has populated them yet.
if Architecture.architectures is None:
    Architecture.load_architectures()

# The page body is rendered only when st_setup returns a truthy value.
if st_setup('LLM Arch'):
    st.write("# Test Runner")

    with st.expander("Pricing Fact Tests"):
        display_pricing_fact_test()

    with st.expander("Custom Tests"):
        display_custom_test()