""" """ from collections import defaultdict import json import os import re from langchain_core.documents import Document from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnablePassthrough from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_together import ChatTogether from langchain_google_genai import ChatGoogleGenerativeAI import streamlit as st import utils_mod import doc_format_mod import guide_mod import sidebar_mod import usage_mod import vectorstore_mod st.set_page_config(layout="wide", page_title="LegisQA") os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"] os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"] os.environ["TOKENIZERS_PARALLELISM"] = "false" SS = st.session_state SEED = 292764 CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118] SPONSOR_PARTIES = ["D", "R", "L", "I"] OPENAI_CHAT_MODELS = { "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}}, "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}}, } ANTHROPIC_CHAT_MODELS = { "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}}, "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}}, "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}}, } TOGETHER_CHAT_MODELS = { "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}}, "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": { "cost": {"pmi": 0.88, "pmo": 0.88} }, "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": { "cost": {"pmi": 5.00, "pmo": 5.00} }, } GOOGLE_CHAT_MODELS = { "gemini-1.5-flash": {"cost": {"pmi": 0.0, "pmo": 0.0}}, "gemini-1.5-pro": {"cost": {"pmi": 0.0, "pmo": 0.0}}, "gemini-1.5-pro-exp-0801": {"cost": {"pmi": 0.0, "pmo": 0.0}}, } PROVIDER_MODELS = { "OpenAI": OPENAI_CHAT_MODELS, "Anthropic": ANTHROPIC_CHAT_MODELS, "Together": TOGETHER_CHAT_MODELS, "Google": GOOGLE_CHAT_MODELS, } def render_example_queries(): with st.expander("Example Queries"): st.write( """ ``` What are the themes around artificial intelligence? ``` ``` Write a well cited 3 paragraph essay on food insecurity. ``` ``` Create a table summarizing major climate change ideas with columns legis_id, title, idea. ``` ``` Write an action plan to keep social security solvent. ``` ``` Suggest reforms that would benefit the Medicaid program. ``` """ ) def get_generative_config(key_prefix: str) -> dict: output = {} key = "provider" output[key] = st.selectbox( label=key, options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|{key}" ) key = "model_name" output[key] = st.selectbox( label=key, options=PROVIDER_MODELS[output["provider"]], key=f"{key_prefix}|{key}", ) key = "temperature" output[key] = st.slider( key, min_value=0.0, max_value=2.0, value=0.0, key=f"{key_prefix}|{key}", ) key = "max_output_tokens" output[key] = st.slider( key, min_value=1024, max_value=2048, key=f"{key_prefix}|{key}", ) key = "top_p" output[key] = st.slider( key, min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|{key}" ) key = "should_escape_markdown" output[key] = st.checkbox( key, value=False, key=f"{key_prefix}|{key}", ) key = "should_add_legis_urls" output[key] = st.checkbox( key, value=True, key=f"{key_prefix}|{key}", ) return output def get_retrieval_config(key_prefix: str) -> dict: output = {} key = "n_ret_docs" output[key] = st.slider( "Number of chunks to retrieve", min_value=1, max_value=32, value=8, key=f"{key_prefix}|{key}", ) key = "filter_legis_id" output[key] = st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|{key}") key = "filter_bioguide_id" output[key] = st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|{key}") key = "filter_congress_nums" output[key] = st.multiselect( "Congress Numbers", CONGRESS_NUMBERS, default=CONGRESS_NUMBERS, key=f"{key_prefix}|{key}", ) key = "filter_sponsor_parties" output[key] = st.multiselect( "Sponsor Party", SPONSOR_PARTIES, default=SPONSOR_PARTIES, key=f"{key_prefix}|{key}", ) return output def get_llm(gen_config: dict): match gen_config["provider"]: case "OpenAI": llm = ChatOpenAI( model=gen_config["model_name"], temperature=gen_config["temperature"], api_key=st.secrets["openai_api_key"], top_p=gen_config["top_p"], seed=SEED, max_tokens=gen_config["max_output_tokens"], ) case "Anthropic": llm = ChatAnthropic( model_name=gen_config["model_name"], temperature=gen_config["temperature"], api_key=st.secrets["anthropic_api_key"], top_p=gen_config["top_p"], max_tokens_to_sample=gen_config["max_output_tokens"], ) case "Together": llm = ChatTogether( model=gen_config["model_name"], temperature=gen_config["temperature"], max_tokens=gen_config["max_output_tokens"], top_p=gen_config["top_p"], seed=SEED, api_key=st.secrets["together_api_key"], ) case "Google": llm = ChatGoogleGenerativeAI( model=gen_config["model_name"], temperature=gen_config["temperature"], api_key=st.secrets["google_api_key"], max_output_tokens=gen_config["max_output_tokens"], top_p=gen_config["top_p"], ) case _: raise ValueError() return llm def create_rag_chain(llm, retriever): QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user. --- Congressional Legislation Excerpts: {context} --- Query: {query}""" prompt = ChatPromptTemplate.from_messages( [ ("human", QUERY_RAG_TEMPLATE), ] ) rag_chain = ( RunnableParallel( { "docs": retriever, "query": RunnablePassthrough(), } ) .assign(context=lambda x: doc_format_mod.format_docs(x["docs"])) .assign(aimessage=prompt | llm) ) return rag_chain def process_query(gen_config: dict, ret_config: dict, query: str): vectorstore = vectorstore_mod.load_pinecone_vectorstore() llm = get_llm(gen_config) vs_filter = vectorstore_mod.get_vectorstore_filter(ret_config) retriever = vectorstore.as_retriever( search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter}, ) rag_chain = create_rag_chain(llm, retriever) response = rag_chain.invoke(query) return response def render_response( response: dict, model_info: dict, provider: str, should_escape_markdown: bool, should_add_legis_urls: bool, tag: str | None = None, ): response_text = response["aimessage"].content if should_escape_markdown: response_text = utils_mod.escape_markdown(response_text) if should_add_legis_urls: response_text = utils_mod.replace_legis_ids_with_urls(response_text) with st.container(border=True): if tag is None: st.write("Response") else: st.write(f"Response ({tag})") st.info(response_text) usage_mod.display_api_usage(response["aimessage"], model_info, provider, tag=tag) doc_format_mod.render_retrieved_chunks(response["docs"], tag=tag) def render_query_rag_tab(): key_prefix = "query_rag" render_example_queries() with st.form(f"{key_prefix}|query_form"): query = st.text_area( "Enter a query that can be answered with congressional legislation:" ) cols = st.columns(2) with cols[0]: query_submitted = st.form_submit_button("Submit") with cols[1]: status_placeholder = st.empty() col1, col2 = st.columns(2) with col1: with st.expander("Generative Config"): gen_config = get_generative_config(key_prefix) with col2: with st.expander("Retrieval Config"): ret_config = get_retrieval_config(key_prefix) rkey = f"{key_prefix}|response" if query_submitted: with status_placeholder: with st.spinner("generating response"): SS[rkey] = process_query(gen_config, ret_config, query) if response := SS.get(rkey): model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]] render_response( response, model_info, gen_config["provider"], gen_config["should_escape_markdown"], gen_config["should_add_legis_urls"], ) with st.expander("Debug"): st.write(response) def render_query_rag_sbs_tab(): base_key_prefix = "query_rag_sbs" with st.form(f"{base_key_prefix}|query_form"): query = st.text_area( "Enter a query that can be answered with congressional legislation:" ) cols = st.columns(2) with cols[0]: query_submitted = st.form_submit_button("Submit") with cols[1]: status_placeholder = st.empty() grp1a, grp2a = st.columns(2) gen_configs = {} ret_configs = {} with grp1a: st.header("Group 1") key_prefix = f"{base_key_prefix}|grp1" with st.expander("Generative Config"): gen_configs["grp1"] = get_generative_config(key_prefix) with st.expander("Retrieval Config"): ret_configs["grp1"] = get_retrieval_config(key_prefix) with grp2a: st.header("Group 2") key_prefix = f"{base_key_prefix}|grp2" with st.expander("Generative Config"): gen_configs["grp2"] = get_generative_config(key_prefix) with st.expander("Retrieval Config"): ret_configs["grp2"] = get_retrieval_config(key_prefix) grp1b, grp2b = st.columns(2) sbs_cols = {"grp1": grp1b, "grp2": grp2b} grp_names = {"grp1": "Group 1", "grp2": "Group 2"} for post_key_prefix in ["grp1", "grp2"]: with sbs_cols[post_key_prefix]: key_prefix = f"{base_key_prefix}|{post_key_prefix}" rkey = f"{key_prefix}|response" if query_submitted: with status_placeholder: with st.spinner( "generating response for {}".format(grp_names[post_key_prefix]) ): SS[rkey] = process_query( gen_configs[post_key_prefix], ret_configs[post_key_prefix], query, ) if response := SS.get(rkey): model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][ gen_configs[post_key_prefix]["model_name"] ] render_response( response, model_info, gen_configs[post_key_prefix]["provider"], gen_configs[post_key_prefix]["should_escape_markdown"], gen_configs[post_key_prefix]["should_add_legis_urls"], tag=grp_names[post_key_prefix], ) def main(): st.title(":classical_building: LegisQA :classical_building:") st.header("Query Congressional Bills") with st.sidebar: sidebar_mod.render_sidebar() query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs( [ "RAG", "RAG (side-by-side)", "Guide", ] ) with query_rag_tab: render_query_rag_tab() with query_rag_sbs_tab: render_query_rag_sbs_tab() with guide_tab: guide_mod.render_guide() if __name__ == "__main__": main()