# hf-legisqa / app.py
# Last commit 80275c5 by gabrielaltay: "lets go"
# (Hugging Face file-viewer header: raw | history | blame | 11.1 kB)
from collections import defaultdict
import json
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
import streamlit as st
# Configure the page before any other Streamlit command runs.
st.set_page_config(page_title="LegisQA", layout="wide")

# Short alias for Streamlit's per-session state mapping; widget values are read
# from here by `key` throughout the script.
SS = st.session_state
# Fixed seed forwarded to the OpenAI chat model via ChatOpenAI's model_kwargs.
SEED = 292764
# Short bill-type code (as stored in the "legis_type" metadata field) ->
# path segment congress.gov uses for that type in bill URLs.
CONGRESS_GOV_TYPE_MAP = dict(
    hconres="house-concurrent-resolution",
    hjres="house-joint-resolution",
    hr="house-bill",
    hres="house-resolution",
    s="senate-bill",
    sconres="senate-concurrent-resolution",
    sjres="senate-joint-resolution",
    sres="senate-resolution",
)
# Chat models offered by the sidebar "model name" selectbox; with no explicit
# index, the selectbox preselects the first entry.
OPENAI_CHAT_MODELS = list(
    (
        "gpt-3.5-turbo-0125",
        "gpt-4-0125-preview",
    )
)
# Opening instruction shared by every prompt template.
PREAMBLE = (
    "You are an expert analyst. Use the following excerpts from US congressional"
    " legislation to respond to the user's query."
)

# Prompt templates keyed by version. Each key matches a document formatter in
# DOC_FORMATTERS, which renders the retrieved chunks into {context}.
PROMPT_TEMPLATES = {
    # v1: chunks are plain text separated by blank lines.
    "v1": PREAMBLE
    + (
        " If you don't know how to respond, just tell the user."
        "\n{context}"
        "\nQuestion: {question}"
    ),
    # v2: each chunk carries its own snippet_num / legis_id / title header.
    "v2": PREAMBLE
    + (
        " Each snippet starts with a header that includes a unique snippet number"
        " (snippet_num), a legis_id, and a title."
        " Your response should reference particular snippets using legis_id and title."
        " If you don't know how to respond, just tell the user."
        "\n{context}"
        "\nQuestion: {question}"
    ),
    # v3: chunks grouped per bill under one legis_id / title header.
    "v3": PREAMBLE
    + (
        " Each excerpt starts with a header that includes a legis_id, and a title"
        " followed by one or more text snippets."
        " When using text snippets in your response, you should mention the legis_id and title."
        " If you don't know how to respond, just tell the user."
        "\n{context}"
        "\nQuestion: {question}"
    ),
    # v4: chunks grouped per bill and serialized as a JSON list.
    "v4": PREAMBLE
    + (
        ' The excerpts are formatted as a JSON list.'
        ' Each JSON object has "legis_id", "title", and "snippets" keys.'
        ' If a snippet is useful in writing part of your response,'
        ' then mention the "title" and "legis_id" inline as you write.'
        " If you don't know how to respond, just tell the user."
        "\n{context}"
        "\nQuery: {question}"
    ),
}
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the congressional Biographical Directory page for a member.

    Args:
        bioguide_id: Bioguide identifier of the member (e.g. "R000595").
    """
    bioguide_base = "https://bioguide.congress.gov/search/bio"
    return f"{bioguide_base}/{bioguide_id}"
def format_ordinal(num: int) -> str:
    """Return ``num`` as an English ordinal string, e.g. 118 -> "118th", 101 -> "101st".

    ``num`` is passed through ``int()`` first, so 118.0 is accepted.
    """
    num = int(num)
    if 11 <= num % 100 <= 13:
        # 11th, 12th, 13th (and 111th, 112th, ...) are irregular.
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(num % 10, "th")
    return f"{num}{suffix}"


def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the congress.gov page URL for a bill or resolution.

    Args:
        congress_num: Congress number (e.g. 118); passed through ``int()``.
        legis_type: Short type code, a key of CONGRESS_GOV_TYPE_MAP (e.g. "hr", "s").
        legis_num: Bill number within that congress and type; passed through ``int()``.

    Raises:
        KeyError: if ``legis_type`` is not a key of CONGRESS_GOV_TYPE_MAP.
    """
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    # congress.gov paths use real ordinals ("101st-congress", "102nd-congress"),
    # so a hard-coded "th" suffix produced broken links for some congresses.
    congress_slug = f"{format_ordinal(congress_num)}-congress"
    return f"https://www.congress.gov/bill/{congress_slug}/{lt}/{int(legis_num)}"
def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the GovTrack.us page URL for a bill or resolution.

    Both numbers are passed through ``int()``, so float metadata values such as
    118.0 produce clean URLs. The type code and bill number are concatenated
    directly (e.g. "s2293").
    """
    congress = int(congress_num)
    bill_slug = f"{legis_type}{int(legis_num)}"
    return f"https://www.govtrack.us/congress/bills/{congress}/{bill_slug}"
@st.cache_resource
def load_bge_embeddings():
    """Load the BAAI/bge-small-en-v1.5 embedding model on CPU.

    Queries are prefixed with BGE's retrieval instruction before embedding.

    Cached with ``st.cache_resource``. Streamlit re-executes the whole script on
    every widget interaction, and ``load_pinecone_vectorstore`` (which calls this
    function) runs at module level. Without the cache, the model weights would
    be reloaded on every rerun. The cached instance is shared across sessions.
    """
    model_name = "BAAI/bge-small-en-v1.5"
    model_kwargs = {"device": "cpu"}
    # Unit-length vectors, matching the cosine distance strategy used for the
    # Pinecone vector store.
    encode_kwargs = {"normalize_embeddings": True}
    emb_fn = HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        query_instruction="Represent this question for searching relevant passages: ",
    )
    return emb_fn
def load_pinecone_vectorstore():
    """Wrap the configured Pinecone index as a LangChain vector store.

    Reads ``pinecone_api_key`` and ``pinecone_index_name`` from ``st.secrets``.
    Queries are embedded with the model from ``load_bge_embeddings`` and scored
    with cosine distance. ``text_key="text"`` names the record field holding
    each chunk's text.
    """
    embeddings = load_bge_embeddings()
    client = Pinecone(api_key=st.secrets["pinecone_api_key"])
    return PineconeVectorStore(
        index=client.Index(st.secrets["pinecone_index_name"]),
        embedding=embeddings,
        text_key="text",
        distance_strategy=DistanceStrategy.COSINE,
    )
def write_outreach_links():
    """Render links to hyperdemocracy, its Nomic Atlas map, and its Hugging Face org."""
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic"
    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
    hf_url = "https://huggingface.co/hyperdemocracy"
    st.subheader(":brain: Learn about [hyperdemocracy](https://hyperdemocracy.us)")
    st.subheader(f":world_map: Visualize with [nomic atlas]({nomic_url})")
    # hf_url must be interpolated; previously the link target was the literal
    # text "hf_url", which produced a broken relative link.
    st.subheader(f":hugging_face: Explore the [huggingface datasets]({hf_url})")
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Bucket retrieved chunks by bill and order them for display.

    Chunks sharing a ``legis_id`` are collected together, with buckets created
    in first-seen order. Each bucket is ordered by the chunks' ``start_index``
    metadata. Buckets are then ordered by size, largest first. Python's sort is
    stable (including with ``reverse=True``), so equally sized buckets keep
    their first-seen order.

    Returns:
        A list of ``(legis_id, chunks)`` tuples.
    """
    by_bill = defaultdict(list)
    for chunk in docs:
        by_bill[chunk.metadata["legis_id"]].append(chunk)

    for chunks in by_bill.values():
        chunks.sort(key=lambda d: d.metadata["start_index"])

    return sorted(by_bill.items(), key=lambda item: len(item[1]), reverse=True)
def format_docs_v1(docs):
    """Concatenate the raw chunk texts, separating chunks with a blank line."""
    separator = "\n\n"
    return separator.join(doc.page_content for doc in docs)
def format_docs_v2(docs):
    """Render each chunk under a header with its position, legis_id and title.

    Each entry has the form::

        snippet_num: <0-based position>
        legis_id: <legis_id>
        title: <title>
        ... <chunk text> ...

    Each entry ends with a newline, and entries are joined with a ``===``
    divider line.
    """
    entry_template = "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n"
    entries = [
        entry_template.format(
            position,
            doc.metadata["legis_id"],
            doc.metadata["title"],
            doc.page_content,
        )
        for position, doc in enumerate(docs)
    ]
    return "\n===\n".join(entries)
def format_docs_v3(docs):
    """Render chunks grouped by bill, with one header per bill.

    Chunks are grouped and ordered by ``group_docs``. Each group is rendered as:

    * a ``legis_id: ...`` / ``title: ...`` header taken from the group's first chunk,
    * a blank line,
    * every chunk as ``... <text> ...``, one per line.

    Groups are joined with a ``===`` divider line.
    """
    sections = []
    for _legis_id, doc_grp in group_docs(docs):
        first_doc = doc_grp[0]
        header = "legis_id: {}\ntitle: {}".format(
            first_doc.metadata["legis_id"],
            first_doc.metadata["title"],
        )
        body = "\n".join("... {} ...\n".format(doc.page_content) for doc in doc_grp)
        sections.append("{}\n\n{}".format(header, body))
    return "\n===\n".join(sections)
def format_docs_v4(docs):
    """Render chunks as a pretty-printed JSON list, one object per bill.

    Chunks are grouped and ordered by ``group_docs``. Each object has the keys
    ``legis_id`` and ``title`` (from the group's first chunk) and ``snippets``
    (the group's chunk texts, in order).
    """
    payload = [
        {
            "legis_id": doc_grp[0].metadata["legis_id"],
            "title": doc_grp[0].metadata["title"],
            "snippets": [doc.page_content for doc in doc_grp],
        }
        for _legis_id, doc_grp in group_docs(docs)
    ]
    # ensure_ascii=False keeps non-ASCII characters common in legislation (e.g.
    # "§", curly quotes, em dashes) verbatim in the prompt. Otherwise they
    # become \uXXXX escapes, which are harder for the model to read and longer.
    return json.dumps(payload, indent=4, ensure_ascii=False)
# Chunk formatter paired with each prompt version; keys mirror PROMPT_TEMPLATES
# so the selected "prompt version" picks both the template and its formatter.
DOC_FORMATTERS = dict(
    v1=format_docs_v1,
    v2=format_docs_v2,
    v3=format_docs_v3,
    v4=format_docs_v4,
)
# Characters that escape_markdown prefixes with a backslash.
# - The backslash itself is listed so existing backslashes render literally.
# - "$" is listed so it does not start math.
# - "|", "~" and ">" are listed so GitHub-flavored tables, strikethrough and
#   blockquotes in the text are not treated as markup.
_MD_SPECIAL_CHARS = "\\`*_{}[]()#+-.!$|~>"
# Single-pass translation table mapping each special char c to "\\" + c.
_MD_ESCAPE_TABLE = str.maketrans({char: "\\" + char for char in _MD_SPECIAL_CHARS})


def escape_markdown(text):
    """Return ``text`` with markdown control characters backslash-escaped.

    Used on model answers and bill text so Streamlit shows them literally.
    ``str.translate`` maps every character exactly once, so backslashes inserted
    for one character are never re-escaped by another rule.
    """
    return text.translate(_MD_ESCAPE_TABLE)
# Sidebar: outreach links plus the generation / retrieval / prompt controls.
# Each widget stores its value in session state (SS) under its `key`; the rest
# of the script reads those values on every rerun.
with st.sidebar:
    with st.container(border=True):
        write_outreach_links()
    # When checked, the answer is passed through escape_markdown before display.
    st.checkbox("escape markdown in answer", key="response_escape_markdown")
    with st.expander("Generative Config"):
        st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
        st.slider(
            "temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
        )
        st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
    with st.expander("Retrieval Config"):
        st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=40,
            value=10,
            key="n_ret_docs",
        )
        # Optional metadata filters; an empty string means "no filter" for that
        # field (see get_vectorstore_filter).
        st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
        st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
        st.text_input("Congress (e.g. 118)", key="filter_congress_num")
    with st.expander("Prompt Config"):
        # index=3 preselects "v4", the fourth key of PROMPT_TEMPLATES.
        st.selectbox(
            label="prompt version",
            options=PROMPT_TEMPLATES.keys(),
            index=3,
            key="prompt_version",
        )
        # Editable template seeded from the selected version. The text in
        # SS["prompt_template"] is what the RAG chain actually uses.
        st.text_area(
            "prompt template",
            PROMPT_TEMPLATES[SS["prompt_version"]],
            height=300,
            key="prompt_template",
        )
# Chat model configured from the sidebar "Generative Config" widgets; top_p and
# the fixed seed are forwarded through model_kwargs.
sampling_kwargs = {"top_p": SS["top_p"], "seed": SEED}
llm = ChatOpenAI(
    model_name=SS["model_name"],
    temperature=SS["temperature"],
    openai_api_key=st.secrets["openai_api_key"],
    model_kwargs=sampling_kwargs,
)

vectorstore = load_pinecone_vectorstore()
# Chunk formatter matching the selected prompt version.
format_docs = DOC_FORMATTERS[SS["prompt_version"]]

# Question box. query_submitted is True only on the rerun triggered by Submit.
with st.form("my_form"):
    st.text_area("Enter question:", key="query")
    query_submitted = st.form_submit_button("Submit")
def get_vectorstore_filter(state=None):
    """Build the Pinecone metadata filter from the sidebar filter inputs.

    Surrounding whitespace is stripped, and blank (empty or whitespace-only)
    inputs are ignored, so a stray space no longer becomes a filter.

    Args:
        state: Mapping holding ``filter_legis_id``, ``filter_bioguide_id`` and
            ``filter_congress_num``. Defaults to the Streamlit session state
            (``SS``), which the sidebar text inputs populate. Missing keys are
            treated as blank.

    Returns:
        A dict for the retriever's ``filter`` search kwarg; empty when no filter
        is set. A congress value that is not an integer is skipped with an
        on-page warning, instead of crashing the app with a ValueError.
    """
    if state is None:
        state = SS

    def _field(key):
        # Text inputs yield strings; treat a missing/None value as blank.
        return (state.get(key) or "").strip()

    legis_id = _field("filter_legis_id")
    bioguide_id = _field("filter_bioguide_id")
    congress_num = _field("filter_congress_num")

    vs_filter = {}
    if legis_id:
        vs_filter["legis_id"] = legis_id
    if bioguide_id:
        vs_filter["sponsor_bioguide_id"] = bioguide_id
    if congress_num:
        try:
            vs_filter["congress_num"] = int(congress_num)
        except ValueError:
            st.warning(
                f"Ignoring congress filter: {congress_num!r} is not a whole number."
            )
    return vs_filter
# Run retrieval + generation only on the rerun triggered by the Submit button.
if query_submitted:
    if not SS["query"].strip():
        # Nothing to search for: skip the Pinecone and OpenAI calls entirely.
        st.warning("Please enter a question before submitting.")
    else:
        vs_filter = get_vectorstore_filter()
        retriever = vectorstore.as_retriever(
            search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
        )
        prompt = PromptTemplate.from_template(SS["prompt_template"])
        # Input is {"context": [Document, ...], "question": str}. The chain
        # renders the docs with the selected formatter, fills the prompt, and
        # returns the model's text.
        rag_chain_from_docs = (
            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
            | prompt
            | llm
            | StrOutputParser()
        )
        # Retrieve docs and pass the query through, then add "answer". The
        # result keeps the retrieved docs under "context" for display below.
        rag_chain_with_source = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain_from_docs)
        out = rag_chain_with_source.invoke(SS["query"])
        # Persist across reruns so the answer stays visible after widget changes.
        SS["out"] = out
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's retrieved chunks inside a collapsible expander.

    The expander label contains:

    * the chunk count and legis_id,
    * the title,
    * congress.gov and govtrack.us links,
    * a link to the sponsor's bioguide page.

    The body lists each chunk prefixed with its ``start_index``, with markdown
    escaped. All header fields come from the first chunk in ``doc_grp``; the
    ``legis_id`` argument is accepted but not used.
    """
    meta = doc_grp[0].metadata
    bill_args = (meta["congress_num"], meta["legis_type"], meta["legis_num"])
    congress_gov_link = f"[congress.gov]({get_congress_gov_url(*bill_args)})"
    gov_track_link = f"[govtrack.us]({get_govtrack_url(*bill_args)})"
    sponsor_id = meta["sponsor_bioguide_id"]

    label = (
        f"{len(doc_grp)} chunks from {meta['legis_id']}\n\n"
        f"{meta['title']}\n\n"
        f"{congress_gov_link} | {gov_track_link}\n\n"
        f"[{meta['sponsor_full_name']} ({sponsor_id}) ]({get_sponsor_url(sponsor_id)})"
    )
    body = "\n\n...\n\n".join(
        f"[start_index={int(chunk.metadata['start_index'])}] {chunk.page_content}"
        for chunk in doc_grp
    )
    with st.expander(label):
        st.write(escape_markdown(body))
# Render the most recent answer (kept in session state across reruns), then the
# retrieved chunks grouped by bill.
out = SS.get("out")
if out:
    answer = out["answer"]
    st.info(escape_markdown(answer) if SS["response_escape_markdown"] else answer)
    for legis_id, doc_grp in group_docs(out["context"]):
        write_doc_grp(legis_id, doc_grp)
    # Retrieved chunks as rendered by the currently selected formatter.
    with st.expander("Debug doc format"):
        st.text_area("formatted docs", value=format_docs(out["context"]), height=600)