Spaces:
Sleeping
Sleeping
import os | |
import random | |
import time | |
from datasets import load_dataset | |
from openai import OpenAI | |
import pandas as pd | |
import streamlit as st | |
# Use Streamlit's "wide" page layout for the whole app.
st.set_page_config(layout="wide")
# Translates the dataset's short legislation type codes (the ``legis_type``
# field of a bill record) into the path segment congress.gov uses for that
# type of legislation.
CONGRESS_GOV_TYPE_MAP = dict(
    hconres="house-concurrent-resolution",
    hjres="house-joint-resolution",
    hr="house-bill",
    hres="house-resolution",
    s="senate-bill",
    sconres="senate-concurrent-resolution",
    sjres="senate-joint-resolution",
    sres="senate-resolution",
)
def get_data(featured_id="118-s-3207", n_sample=100, random_state=None):
    """Load the unified US Congress dataset and return a small working subset.

    Every split of ``hyperdemocracy/us-congress`` (config ``unified_v1``) is
    concatenated, the first text version of each bill is exposed as a new
    ``text`` column, and bills without any text are dropped.

    Args:
        featured_id: ``legis_id`` of a bill that is always placed first.
        n_sample: maximum number of additional bills to sample at random.
        random_state: seed forwarded to ``DataFrame.sample``; ``None`` keeps
            the sample non-deterministic.

    Returns:
        A ``pandas.DataFrame``: the row(s) matching ``featured_id`` followed by
        up to ``n_sample`` other randomly sampled bills.
    """
    dsd = load_dataset("hyperdemocracy/us-congress", "unified_v1")
    # ignore_index avoids duplicate index labels coming from different splits.
    df = pd.concat([ds.to_pandas() for ds in dsd.values()], ignore_index=True)
    # Use the first text version when one exists, otherwise an empty string.
    df["text"] = df["textversions"].apply(
        lambda x: x[0]["text_v1"] if x is not None and len(x) > 0 else ""
    )
    df = df[df["text"].str.len() > 0]

    is_featured = df["legis_id"] == featured_id
    featured = df[is_featured]
    # Sample only from the remaining bills so the featured bill is never
    # listed twice, and cap the size so a small pool does not raise ValueError.
    others = df[~is_featured]
    sampled = others.sample(n=min(n_sample, len(others)), random_state=random_state)
    return pd.concat([featured, sampled])
# Characters with special meaning in Markdown (including GFM tables and
# strikethrough, and `$`, which Streamlit uses for LaTeX). The backslash itself
# is included so backslashes already in the text are escaped too.
_MD_SPECIAL_CHARS = "\\`*_{}[]()<>#+-.!|~$"
# A single translation table escapes every character exactly once in one pass,
# independent of the order of the characters above.
_MD_ESCAPE_TABLE = str.maketrans({char: "\\" + char for char in _MD_SPECIAL_CHARS})


def escape_markdown(text):
    """Return *text* with every Markdown special character backslash-escaped.

    Used so raw bill text renders literally in ``st.markdown`` instead of
    being interpreted as headings, emphasis, lists, tables, math, etc.

    Args:
        text: the string to escape.

    Returns:
        A new string in which each character from ``_MD_SPECIAL_CHARS`` is
        preceded by a backslash.
    """
    return text.translate(_MD_ESCAPE_TABLE)
def get_sponsor_url(bioguide_id):
    """Build the bioguide.congress.gov biography URL for a bioguide ID.

    Args:
        bioguide_id: identifier appended as the last path segment.

    Returns:
        The biography search URL as a string.
    """
    base = "https://bioguide.congress.gov/search/bio/"
    return base + f"{bioguide_id}"
def get_congress_gov_url(congress_num, legis_type, legis_num):
    """Build the congress.gov page URL for a piece of legislation.

    Args:
        congress_num: Congress number (e.g. ``118``); an int or numeric string.
        legis_type: short type code, a key of ``CONGRESS_GOV_TYPE_MAP``
            (e.g. ``"hr"`` or ``"s"``).
        legis_num: the bill number within that type.

    Returns:
        A URL such as
        ``https://www.congress.gov/bill/118th-congress/senate-bill/3207``.

    Raises:
        KeyError: if ``legis_type`` is not a key of ``CONGRESS_GOV_TYPE_MAP``.
        ValueError: if ``congress_num`` is not an integer value.
    """
    number = int(congress_num)
    # English ordinal suffix. 11-13 are irregular ("111th"); otherwise the last
    # digit decides ("101st", "102nd", "103rd", "118th"). A fixed "th" suffix
    # produced invalid paths such as "101th-congress".
    if number % 100 in (11, 12, 13):
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    return f"https://www.congress.gov/bill/{number}{suffix}-congress/{lt}/{legis_num}"
def show_bill(bdict):
    """Render one bill's metadata, summary, and full text with Streamlit.

    Args:
        bdict: a single row of the bills DataFrame as a dict. Reads the keys
            ``legis_id``, ``congress_num``, ``legis_type``, ``legis_num``,
            ``text`` and ``metadata`` (``sponsors``, ``title``,
            ``introduced_date``, ``policy_area``, ``subjects``, ``summaries``).
    """
    metadata = bdict["metadata"]
    text = bdict["text"]
    bill_url = get_congress_gov_url(
        bdict["congress_num"],
        bdict["legis_type"],
        bdict["legis_num"],
    )

    st.header("Metadata")
    st.write("**Bill ID**: [{}]({})".format(bdict["legis_id"], bill_url))
    sponsors = metadata["sponsors"]
    # Check with len() rather than truthiness: list-valued fields may be numpy
    # arrays after DataFrame conversion, whose truth value is ambiguous.
    if sponsors is not None and len(sponsors) > 0:
        sponsor = sponsors[0]
        sponsor_url = get_sponsor_url(sponsor["bioguide_id"])
        st.write("**Sponsor**: [{}]({})".format(sponsor["full_name"], sponsor_url))
    else:
        # Previously an IndexError here aborted the whole panel.
        st.write("**Sponsor**: Not Available")
    st.write("**Title**: {}".format(metadata["title"]))
    st.write("**Introduced**: {}".format(metadata["introduced_date"]))
    st.write("**Policy Area**: {}".format(metadata["policy_area"]))
    st.write("**Subjects**: {}".format(metadata["subjects"]))
    st.write("**Character Count**: {}".format(len(text)))
    # Rough ~4-characters-per-token heuristic, shown as a whole number.
    st.write("**Estimated Tokens**: {}".format(len(text) // 4))

    st.header("Summary")
    summaries = metadata["summaries"]
    if summaries is not None and len(summaries) > 0:
        st.write(summaries[0])
        # st.markdown(bdict["metadata"]["summaries"][0]["text"], unsafe_allow_html=True)
    else:
        st.write("Not Available")

    st.header("Text")
    # Escape so the raw bill text renders literally rather than as Markdown.
    st.markdown(escape_markdown(text))
# Per-session state. "openai_model" and "openai_api_key" are also widget keys
# below, so they are initialised before those widgets are created.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "openai_model" not in st.session_state:
    st.session_state["openai_model"] = "gpt-3.5-turbo-0125"
if "openai_api_key" not in st.session_state:
    st.session_state["openai_api_key"] = None

# Streamlit re-executes this script on every interaction. Without caching, each
# rerun reloaded the whole dataset and, because get_data() samples randomly,
# drew a different set of bills (changing the "Legis ID" options mid-session).
df = st.cache_data(get_data)()

with st.sidebar:
    st.header("Configuration")
    openai_api_key = st.text_input(
        label="OpenAI API Key:",
        help="Required for OpenAI Models",
        type="password",
        key="openai_api_key",
    )
    MODELS = ["gpt-3.5-turbo-0125", "gpt-4-0125-preview"]
    st.selectbox("Model Name", MODELS, key="openai_model")
    LEGIS_IDS = df["legis_id"].to_list()
    st.selectbox("Legis ID", LEGIS_IDS, key="legis_id")
    bdict = df[df["legis_id"] == st.session_state["legis_id"]].iloc[0].to_dict()
    if st.button("Clear Messages"):
        st.session_state["messages"] = []
    st.header("Debug")
    with st.expander("Show Messages"):
        st.write(st.session_state["messages"])
    with st.expander("Show Bill Dictionary"):
        st.write(bdict)

# The selected bill's full text is supplied to the model as system context.
# The separator gets its own line instead of being glued onto the bill text.
system_message = {
    "role": "system",
    "content": (
        "You are a helpful legislative question answering assistant. "
        "Use the following legislative text to help answer user questions."
        "\n\n---\n\n" + bdict["text"]
    ),
}

with st.expander("Show Bill Details"):
    with st.container(height=600):
        show_bill(bdict)

# Replay the chat history on every rerun.
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("How can I help you understand this bill?"):
    # The key starts as None (initialised above) and an emptied text field
    # yields ""; both must block the request instead of calling the API.
    if not st.session_state["openai_api_key"]:
        st.warning("Enter API key to chat")
        st.stop()
    client = OpenAI(api_key=openai_api_key)

    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        stream = client.chat.completions.create(
            model=st.session_state["openai_model"],
            messages=[system_message] + [
                {"role": msg["role"], "content": msg["content"]}
                for msg in st.session_state["messages"]
            ],
            temperature=0.0,
            stream=True,
        )
        response = st.write_stream(stream)
    st.session_state["messages"].append({"role": "assistant", "content": response})