"""LegisQA: a Streamlit RAG app for querying US congressional legislation.

Retrieves bill text chunks from a Pinecone vector store (BGE embeddings)
and answers user queries with an OpenAI chat model through a LangChain
retrieval-augmented-generation pipeline.
"""

from collections import defaultdict
import json

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
import streamlit as st

st.set_page_config(layout="wide", page_title="LegisQA")
SS = st.session_state
# Fixed seed forwarded to the OpenAI API for (best-effort) reproducible sampling.
SEED = 292764

# Maps short legislation type codes (as stored in doc metadata) to the slugs
# used in congress.gov bill URLs.
CONGRESS_GOV_TYPE_MAP = {
    "hconres": "house-concurrent-resolution",
    "hjres": "house-joint-resolution",
    "hr": "house-bill",
    "hres": "house-resolution",
    "s": "senate-bill",
    "sconres": "senate-concurrent-resolution",
    "sjres": "senate-joint-resolution",
    "sres": "senate-resolution",
}

OPENAI_CHAT_MODELS = [
    "gpt-3.5-turbo-0125",
    "gpt-4-0125-preview",
]

PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."

# Each prompt version pairs with the same-named context formatter in
# DOC_FORMATTERS below (v1: plain join, v2: numbered snippets, v3: grouped
# by bill, v4: grouped JSON).
PROMPT_TEMPLATES = {
    "v1": PREAMBLE
    + """
If you don't know how to respond, just tell the user.

{context}

Question: {question}""",
    "v2": PREAMBLE
    + """
Each snippet starts with a header that includes a unique snippet number (snippet_num), a legis_id, and a title. Your response should reference particular snippets using legis_id and title. If you don't know how to respond, just tell the user.

{context}

Question: {question}""",
    "v3": PREAMBLE
    + """
Each excerpt starts with a header that includes a legis_id, and a title followed by one or more text snippets. When using text snippets in your response, you should mention the legis_id and title. If you don't know how to respond, just tell the user.

{context}

Question: {question}""",
    "v4": PREAMBLE
    + """
The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", and "snippets" keys. If a snippet is useful in writing part of your response, then mention the "title" and "legis_id" inline as you write. If you don't know how to respond, just tell the user.

{context}

Query: {question}""",
}


def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov profile URL for a sponsor."""
    return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"


def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the congress.gov URL for a bill.

    Raises:
        KeyError: if ``legis_type`` is not in CONGRESS_GOV_TYPE_MAP.
    """
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"


def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the govtrack.us URL for a bill."""
    return f"https://www.govtrack.us/congress/bills/{int(congress_num)}/{legis_type}{int(legis_num)}"


def load_bge_embeddings():
    """Build the BGE embedding function used for both indexing and queries."""
    model_name = "BAAI/bge-small-en-v1.5"
    model_kwargs = {"device": "cpu"}
    # BGE embeddings are meant to be compared with cosine similarity, so
    # normalize them at encode time.
    encode_kwargs = {"normalize_embeddings": True}
    emb_fn = HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        query_instruction="Represent this question for searching relevant passages: ",
    )
    return emb_fn


def load_pinecone_vectorstore():
    """Connect to the Pinecone index (credentials from st.secrets)."""
    emb_fn = load_bge_embeddings()
    pc = Pinecone(api_key=st.secrets["pinecone_api_key"])
    index = pc.Index(st.secrets["pinecone_index_name"])
    vectorstore = PineconeVectorStore(
        index=index,
        embedding=emb_fn,
        text_key="text",
        distance_strategy=DistanceStrategy.COSINE,
    )
    return vectorstore


def write_outreach_links():
    """Render sidebar links to the project's related resources."""
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic"
    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
    hf_url = "https://huggingface.co/hyperdemocracy"
    st.subheader(":brain: Learn about [hyperdemocracy](https://hyperdemocracy.us)")
    st.subheader(f":world_map: Visualize with [nomic atlas]({nomic_url})")
    # BUG FIX: the braces were missing around hf_url, so the rendered link
    # pointed at the literal text "hf_url" instead of the hugging face URL.
    st.subheader(f":hugging_face: Explore the [huggingface datasets]({hf_url})")


def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group retrieved docs by bill.

    Returns (legis_id, docs) pairs where each group's docs are sorted by
    their ``start_index`` metadata (document order within the bill) and the
    groups themselves are sorted by descending group size.
    """
    # create legis_id groups
    doc_grps = defaultdict(list)
    for doc in docs:
        doc_grps[doc.metadata["legis_id"]].append(doc)
    # sort docs in each group by start index
    for legis_id in doc_grps.keys():
        doc_grps[legis_id] = sorted(
            doc_grps[legis_id],
            key=lambda x: x.metadata["start_index"],
        )
    # sort groups by number of docs (largest group first)
    doc_grps = sorted(
        doc_grps.items(),
        key=lambda x: -len(x[1]),
    )
    return doc_grps


def format_docs_v1(docs):
    """Simple double new line join"""
    return "\n\n".join([doc.page_content for doc in docs])


def format_docs_v2(docs):
    """Format with snippet_num, legis_id, and title"""

    def format_doc(idoc, doc):
        return "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n".format(
            idoc,
            doc.metadata["legis_id"],
            doc.metadata["title"],
            doc.page_content,
        )

    snips = []
    for idoc, doc in enumerate(docs):
        txt = format_doc(idoc, doc)
        snips.append(txt)
    return "\n===\n".join(snips)


def format_docs_v3(docs):
    """Format docs grouped by bill: one header per bill, then its snippets."""

    def format_header(doc):
        return "legis_id: {}\ntitle: {}".format(
            doc.metadata["legis_id"],
            doc.metadata["title"],
        )

    def format_content(doc):
        return "... {} ...\n".format(
            doc.page_content,
        )

    snips = []
    doc_grps = group_docs(docs)
    for legis_id, doc_grp in doc_grps:
        # All docs in a group share legis_id/title; take them from the first.
        head = format_header(doc_grp[0])
        contents = [format_content(doc) for doc in doc_grp]
        snips.append("{}\n\n{}".format(head, "\n".join(contents)))
    return "\n===\n".join(snips)


def format_docs_v4(docs):
    """JSON grouped"""
    doc_grps = group_docs(docs)
    out = []
    for legis_id, doc_grp in doc_grps:
        dd = {
            "legis_id": doc_grp[0].metadata["legis_id"],
            "title": doc_grp[0].metadata["title"],
            "snippets": [doc.page_content for doc in doc_grp],
        }
        out.append(dd)
    return json.dumps(out, indent=4)


# Context formatter per prompt version (keys match PROMPT_TEMPLATES).
DOC_FORMATTERS = {
    "v1": format_docs_v1,
    "v2": format_docs_v2,
    "v3": format_docs_v3,
    "v4": format_docs_v4,
}


def escape_markdown(text):
    """Backslash-escape markdown special characters in ``text``.

    The backslash itself is escaped first so that the escapes added for the
    remaining characters are not themselves re-escaped.
    """
    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
    for char in MD_SPECIAL_CHARS:
        text = text.replace(char, "\\" + char)
    return text


# ---------------------------------------------------------------------------
# Sidebar: outreach links plus generation / retrieval / prompt configuration.
# Widget values are persisted in st.session_state via their `key` arguments.
# ---------------------------------------------------------------------------
with st.sidebar:
    with st.container(border=True):
        write_outreach_links()

    st.checkbox("escape markdown in answer", key="response_escape_markdown")

    with st.expander("Generative Config"):
        st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
        st.slider(
            "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
        )
        st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")

    with st.expander("Retrieval Config"):
        st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=40,
            value=10,
            key="n_ret_docs",
        )
        st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
        st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
        st.text_input("Congress (e.g. 118)", key="filter_congress_num")

    with st.expander("Prompt Config"):
        st.selectbox(
            label="prompt version",
            options=PROMPT_TEMPLATES.keys(),
            index=3,
            key="prompt_version",
        )
        st.text_area(
            "prompt template",
            PROMPT_TEMPLATES[SS["prompt_version"]],
            height=300,
            key="prompt_template",
        )

llm = ChatOpenAI(
    model_name=SS["model_name"],
    temperature=SS["temperature"],
    openai_api_key=st.secrets["openai_api_key"],
    model_kwargs={"top_p": SS["top_p"], "seed": SEED},
)
vectorstore = load_pinecone_vectorstore()
format_docs = DOC_FORMATTERS[SS["prompt_version"]]

with st.form("my_form"):
    st.text_area("Enter question:", key="query")
    query_submitted = st.form_submit_button("Submit")


def get_vectorstore_filter():
    """Build a Pinecone metadata filter from the sidebar text inputs.

    Empty inputs are omitted; the congress number is cast to int to match
    the metadata type in the index.
    """
    vs_filter = {}
    if SS["filter_legis_id"] != "":
        vs_filter["legis_id"] = SS["filter_legis_id"]
    if SS["filter_bioguide_id"] != "":
        vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
    if SS["filter_congress_num"] != "":
        vs_filter["congress_num"] = int(SS["filter_congress_num"])
    return vs_filter


if query_submitted:
    vs_filter = get_vectorstore_filter()
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
    )
    prompt = PromptTemplate.from_template(SS["prompt_template"])
    # Inner chain: format the retrieved docs into the prompt context, then
    # run the LLM and parse its message to a plain string.
    rag_chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | prompt
        | llm
        | StrOutputParser()
    )
    # Outer chain: keep the raw retrieved docs ("context") alongside the
    # generated "answer" so the sources can be displayed below.
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=rag_chain_from_docs)
    out = rag_chain_with_source.invoke(SS["query"])
    # Persist the result so it survives Streamlit reruns.
    SS["out"] = out


def write_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's retrieved chunks as an expandable sources panel."""
    first_doc = doc_grp[0]
    congress_gov_url = get_congress_gov_url(
        first_doc.metadata["congress_num"],
        first_doc.metadata["legis_type"],
        first_doc.metadata["legis_num"],
    )
    congress_gov_link = f"[congress.gov]({congress_gov_url})"
    gov_track_url = get_govtrack_url(
        first_doc.metadata["congress_num"],
        first_doc.metadata["legis_type"],
        first_doc.metadata["legis_num"],
    )
    gov_track_link = f"[govtrack.us]({gov_track_url})"
    ref = "{} chunks from {}\n\n{}\n\n{} | {}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        first_doc.metadata["legis_id"],
        first_doc.metadata["title"],
        congress_gov_link,
        gov_track_link,
        first_doc.metadata["sponsor_full_name"],
        first_doc.metadata["sponsor_bioguide_id"],
        get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
    )
    doc_contents = [
        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
        for doc in doc_grp
    ]
    with st.expander(ref):
        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))


out = SS.get("out")
if out:
    if SS["response_escape_markdown"]:
        st.info(escape_markdown(out["answer"]))
    else:
        st.info(out["answer"])
    doc_grps = group_docs(out["context"])
    for legis_id, doc_grp in doc_grps:
        write_doc_grp(legis_id, doc_grp)
    with st.expander("Debug doc format"):
        st.text_area("formatted docs", value=format_docs(out["context"]), height=600)