"""Streamlit demo for Solis.

Flow: the LLM answers the user's question (step 0). Then it answers several
substituted variants of the question produced by ``fp_substitute`` (step 2).
Finally the collected predictions are handed to ``try_search`` (step 3),
which either returns an expression/result pair or ``None``.
"""
import argparse
import random
import traceback

import streamlit as st

from llm_src.utils.cot.get_prompt import get_prompt
from llm_src.utils.decoder import Decoder, answer_cleansing
from llm_src.utils.fp_substitution import fp_substitute, get_nums_from_passage
from llm_src.utils.solis.solis_solver import try_search
from llm_src.utils.solis.helper import *

# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="Solis Demo",
    page_icon="🐈",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': "Welcome to check Reflection-Of-Thought [website](https://reflection-of-thought.github.io/)!"
    },
)

test_examples = [
    "Nancy uploaded 41 pictures to Facebook. She put 37 pics into one album and put the rest into 2 different albums. How many pictures were in each album?",
]

# Questions containing more numbers than this are rejected before any LLM call.
MAX_OPERANDS = 3


def read_markdown(path):
    """Render the markdown file at *path* into the page (raw HTML allowed)."""
    # Explicit UTF-8 so decoding does not depend on the host's locale.
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)


# Set up
def get_default_argument():
    """Parse the demo's command-line options and return the argparse Namespace."""
    parser = argparse.ArgumentParser(description="Solis")
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--api_time_interval", type=float, default=2)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--substitute_time", type=int, default=5)
    parser.add_argument("--dataset", type=str, default="multiarith")
    parser.add_argument("--direct_answer_trigger_for_fewshot", type=str, default="The answer is")
    args = parser.parse_args()
    return args


def _predict(decoder, args, prompt):
    """Decode *prompt* with the LLM and return the cleansed answer.

    Any exception from ``decoder.decode`` / ``answer_cleansing`` propagates.
    """
    return answer_cleansing(args, decoder.decode(args, prompt))


args = get_default_argument()
keys = [st.secrets["api_key_0"], st.secrets["api_key_1"], st.secrets["api_key_2"]]
decoder = Decoder(keys=keys)

# Main Demo
st.markdown("# Solis Demo")
# Demo description
read_markdown('md_src/demo_description.md')

col1, _ = st.columns([9, 1])
with col1:
    x = st.text_input(
        "Ask a question which requires numerical reasoning:",
        value=test_examples[0],
    )
    button = st.button("Run Solis")

# Streamlit reruns the script on every interaction; do nothing until clicked.
if not button:
    st.stop()

# Reseed on every run from --seed (default 123).
random.seed(args.seed)
prompt_x = get_prompt()
orig_nums, _ = get_nums_from_passage(x)
if len(orig_nums) > MAX_OPERANDS:
    st.markdown("**Too many operands!**")
    st.stop()

col1, _ = st.columns(2)
orig_x = prompt_x + f"Q: {x}\nA:"

# step 0, original predict
with col1:
    try:
        orig_z = _predict(decoder, args, orig_x)
    except Exception:
        st.markdown('too frequent! please retry')
        traceback.print_exc()
        # Without the original prediction there is nothing to substitute or
        # search over, so end this run here.
        st.stop()
    st.markdown(f'**original prediction:** {orig_z}')

# step 1, #TODO skip operand proposal

# step 2, substitute
fp_data_list = fp_substitute(x, args.substitute_time)
fp_results = []
for fp_data in fp_data_list:
    fp_x = prompt_x + f"Q: {fp_data['Question']}\nA:"
    try:
        fp_z = _predict(decoder, args, fp_x)
    except Exception:
        traceback.print_exc()
        # Skip this substitution entirely so that every entry in fp_results
        # pairs its numbers with the prediction actually made for them.
        continue
    with col1:
        st.markdown("**Substituted Q:**")
        st.text(f"{fp_data['Question']}")
        st.markdown(f"**Prediction:** {fp_z}")
        st.markdown("---")
    fp_results.append({
        "fp_nums": fp_data["Alignments"],
        "fp_z": fp_z,
    })

# step 3, arith relationship inversion
with col1:
    expr, pred = try_search(args, orig_nums, fp_results)
    if expr is None:
        st.markdown('not able to solve.')
    else:
        st.subheader("Solis Output")
        st.markdown(f'**expression:{expr}**')
        st.markdown(f'**results:{pred}**')