gabrielaltay committed on
Commit
723ae91
1 Parent(s): 2029299
Files changed (1) hide show
  1. app.py +162 -98
app.py CHANGED
@@ -1,11 +1,13 @@
1
  from collections import defaultdict
2
  import json
 
3
 
4
  from langchain_core.documents import Document
5
  from langchain_core.prompts import PromptTemplate
6
  from langchain_core.runnables import RunnableParallel
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_core.output_parsers import StrOutputParser
 
9
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
10
  from langchain_community.vectorstores.utils import DistanceStrategy
11
  from langchain_openai import ChatOpenAI
@@ -19,6 +21,7 @@ SS = st.session_state
19
 
20
  SEED = 292764
21
  CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
 
22
  CONGRESS_GOV_TYPE_MAP = {
23
  "hconres": "house-concurrent-resolution",
24
  "hjres": "house-joint-resolution",
@@ -29,7 +32,6 @@ CONGRESS_GOV_TYPE_MAP = {
29
  "sjres": "senate-joint-resolution",
30
  "sres": "senate-resolution",
31
  }
32
-
33
  OPENAI_CHAT_MODELS = [
34
  "gpt-3.5-turbo-0125",
35
  "gpt-4-0125-preview",
@@ -115,6 +117,7 @@ def write_outreach_links():
115
  st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
116
  st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
117
 
 
118
  def group_docs(docs) -> list[tuple[str, list[Document]]]:
119
  doc_grps = defaultdict(list)
120
 
@@ -219,15 +222,96 @@ def escape_markdown(text):
219
  return text
220
 
221
 
222
- st.title(":classical_building: LegisQA :classical_building:")
223
- st.header("Explore Congressional Legislation")
224
- st.write(
225
- """When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
226
- This technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around. Try different models. Find questions that work and find questions that fail.""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
- st.header("Example Queries")
 
 
229
 
230
- st.write("""
231
  ```
232
  What are the themes around artificial intelligence?
233
  ```
@@ -239,8 +323,15 @@ Write a well cited 3 paragraph essay on food insecurity.
239
  ```
240
  Create a table summarizing the major climate change ideas with columns legis_id, title, idea.
241
  ```
242
- """
243
- )
 
 
 
 
 
 
 
244
 
245
 
246
  with st.sidebar:
@@ -249,6 +340,7 @@ with st.sidebar:
249
  write_outreach_links()
250
 
251
  st.checkbox("escape markdown in answer", key="response_escape_markdown")
 
252
 
253
  with st.expander("Generative Config"):
254
  st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
@@ -261,20 +353,24 @@ with st.sidebar:
261
  st.slider(
262
  "Number of chunks to retrieve",
263
  min_value=1,
264
- max_value=40,
265
- value=10,
266
  key="n_ret_docs",
267
  )
268
  st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
269
  st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
270
- # st.text_input("Congress (e.g. 118)", key="filter_congress_num")
271
  st.multiselect(
272
  "Congress Numbers",
273
  CONGRESS_NUMBERS,
274
  default=CONGRESS_NUMBERS,
275
  key="filter_congress_nums",
276
  )
277
-
 
 
 
 
 
278
 
279
  with st.expander("Prompt Config"):
280
  st.selectbox(
@@ -297,97 +393,65 @@ llm = ChatOpenAI(
297
  openai_api_key=st.secrets["openai_api_key"],
298
  model_kwargs={"top_p": SS["top_p"], "seed": SEED},
299
  )
300
-
301
  vectorstore = load_pinecone_vectorstore()
302
  format_docs = DOC_FORMATTERS[SS["prompt_version"]]
 
303
 
304
- with st.form("my_form"):
305
- st.text_area("Enter query:", key="query")
306
- query_submitted = st.form_submit_button("Submit")
307
 
308
-
309
- def get_vectorstore_filter():
310
- vs_filter = {}
311
- if SS["filter_legis_id"] != "":
312
- vs_filter["legis_id"] = SS["filter_legis_id"]
313
- if SS["filter_bioguide_id"] != "":
314
- vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
315
- # if SS["filter_congress_num"] != "":
316
- # vs_filter["congress_num"] = int(SS["filter_congress_num"])
317
- vs_filter = {"congress_num": {"$in": SS["filter_congress_nums"]}}
318
- return vs_filter
319
 
320
 
321
- if query_submitted:
322
 
323
- vs_filter = get_vectorstore_filter()
324
- with st.sidebar:
325
- with st.expander("Debug vs_filter"):
326
- st.write(vs_filter)
327
- retriever = vectorstore.as_retriever(
328
- search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
329
- )
330
- prompt = PromptTemplate.from_template(SS["prompt_template"])
331
- rag_chain_from_docs = (
332
- RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
333
- | prompt
334
- | llm
335
- | StrOutputParser()
336
- )
337
- rag_chain_with_source = RunnableParallel(
338
- {"context": retriever, "question": RunnablePassthrough()}
339
- ).assign(answer=rag_chain_from_docs)
340
- out = rag_chain_with_source.invoke(SS["query"])
341
- SS["out"] = out
342
 
 
 
 
343
 
344
- def write_doc_grp(legis_id: str, doc_grp: list[Document]):
345
- first_doc = doc_grp[0]
346
-
347
- congress_gov_url = get_congress_gov_url(
348
- first_doc.metadata["congress_num"],
349
- first_doc.metadata["legis_type"],
350
- first_doc.metadata["legis_num"],
351
- )
352
- congress_gov_link = f"[congress.gov]({congress_gov_url})"
353
-
354
- gov_track_url = get_govtrack_url(
355
- first_doc.metadata["congress_num"],
356
- first_doc.metadata["legis_type"],
357
- first_doc.metadata["legis_num"],
358
- )
359
- gov_track_link = f"[govtrack.us]({gov_track_url})"
360
-
361
- ref = "{} chunks from {}\n\n{}\n\n{} | {}\n\n[{} ({}) ]({})".format(
362
- len(doc_grp),
363
- first_doc.metadata["legis_id"],
364
- first_doc.metadata["title"],
365
- congress_gov_link,
366
- gov_track_link,
367
- first_doc.metadata["sponsor_full_name"],
368
- first_doc.metadata["sponsor_bioguide_id"],
369
- get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
370
- )
371
- doc_contents = [
372
- "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
373
- for doc in doc_grp
374
- ]
375
- with st.expander(ref):
376
- st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
377
-
378
-
379
- out = SS.get("out")
380
- if out:
381
-
382
- if SS["response_escape_markdown"]:
383
- st.info(escape_markdown(out["answer"]))
384
- else:
385
- st.info(out["answer"])
386
-
387
- doc_grps = group_docs(out["context"])
388
- for legis_id, doc_grp in doc_grps:
389
- write_doc_grp(legis_id, doc_grp)
390
-
391
- with st.expander("Debug doc format"):
392
- st.text_area("formatted docs", value=format_docs(out["context"]), height=600)
393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from collections import defaultdict
2
  import json
3
+ import re
4
 
5
  from langchain_core.documents import Document
6
  from langchain_core.prompts import PromptTemplate
7
  from langchain_core.runnables import RunnableParallel
8
  from langchain_core.runnables import RunnablePassthrough
9
  from langchain_core.output_parsers import StrOutputParser
10
+ from langchain_community.callbacks import get_openai_callback
11
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
12
  from langchain_community.vectorstores.utils import DistanceStrategy
13
  from langchain_openai import ChatOpenAI
 
21
 
22
  SEED = 292764
23
  CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
24
+ SPONSOR_PARTIES = ["D", "R", "L", "I"]
25
  CONGRESS_GOV_TYPE_MAP = {
26
  "hconres": "house-concurrent-resolution",
27
  "hjres": "house-joint-resolution",
 
32
  "sjres": "senate-joint-resolution",
33
  "sres": "senate-resolution",
34
  }
 
35
  OPENAI_CHAT_MODELS = [
36
  "gpt-3.5-turbo-0125",
37
  "gpt-4-0125-preview",
 
117
  st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
118
  st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
119
 
120
+
121
  def group_docs(docs) -> list[tuple[str, list[Document]]]:
122
  doc_grps = defaultdict(list)
123
 
 
222
  return text
223
 
224
 
225
def get_vectorstore_filter():
    """Assemble the Pinecone metadata filter from sidebar session state.

    The congress-number and sponsor-party ``$in`` clauses are always
    present; the bill-id and bioguide-id exact matches are added only
    when the corresponding text inputs are non-empty.
    """
    vs_filter = {
        "congress_num": {"$in": SS["filter_congress_nums"]},
        "sponsor_party": {"$in": SS["filter_sponsor_parties"]},
    }
    # st.text_input always yields a str, so truthiness == (value != "")
    if SS["filter_legis_id"]:
        vs_filter["legis_id"] = SS["filter_legis_id"]
    if SS["filter_bioguide_id"]:
        vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
    return vs_filter
234
+
235
+
236
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's group of retrieved chunks as a streamlit expander.

    The expander label shows the chunk count, bill id, title, a
    congress.gov link, and a sponsor link; the body shows the chunk
    texts separated by ellipses.

    Args:
        legis_id: bill identifier shared by the group (the same value is
            read from the first doc's metadata).
        doc_grp: retrieved chunks belonging to the same bill.
    """
    first_doc = doc_grp[0]

    congress_gov_url = get_congress_gov_url(
        first_doc.metadata["congress_num"],
        first_doc.metadata["legis_type"],
        first_doc.metadata["legis_num"],
    )
    congress_gov_link = f"[congress.gov]({congress_gov_url})"

    # NOTE: the govtrack.us link is no longer part of the label format
    # string, so the previously dead gov_track_url/gov_track_link
    # computation has been removed.
    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        first_doc.metadata["legis_id"],
        first_doc.metadata["title"],
        congress_gov_link,
        first_doc.metadata["sponsor_full_name"],
        first_doc.metadata["sponsor_bioguide_id"],
        get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
    )
    doc_contents = [
        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
        for doc in doc_grp
    ]
    with st.expander(ref):
        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
268
+
269
+
270
def legis_id_to_link(legis_id: str) -> str:
    """Convert a ``<congress>-<type>-<num>`` bill id into a congress.gov URL."""
    cnum, ltype, lnum = legis_id.split("-")
    return get_congress_gov_url(cnum, ltype, lnum)
273
+
274
+
275
def legis_id_match_to_link(matchobj):
    """``re.sub`` callback: wrap a matched legis_id in a markdown link.

    Args:
        matchobj: an ``re.Match`` whose full match is a legis_id string.

    Returns:
        A markdown link ``[legis_id](congress.gov url)``.
    """
    # Match.group(0) is the idiomatic way to get the full matched text,
    # instead of slicing matchobj.string[matchobj.start():matchobj.end()].
    mstring = matchobj.group(0)
    url = legis_id_to_link(mstring)
    return f"[{mstring}]({url})"
280
+
281
+
282
def replace_legis_ids_with_urls(text):
    """Linkify bill ids (congresses 113-118) found in *text*.

    Matches patterns like ``118-s-2293`` and replaces each with a
    markdown link to its congress.gov page.

    Args:
        text: markdown/plain text, typically the LLM answer.

    Returns:
        The text with every matched legis_id replaced by a markdown link.
    """
    # Raw string: "\d" inside a plain literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+, error in the future).
    pattern = r"11[345678]-[a-z]+-\d{1,5}"
    return re.sub(pattern, legis_id_match_to_link, text)
286
+
287
+
288
def write_guide():
    """Render the 'guide' tab: RAG overview, disclaimer, and sidebar help."""
    # Markdown kept flush-left inside the literal so headers render
    # correctly — NOTE(review): exact original string indentation is not
    # recoverable from the diff view; confirm against the deployed app.
    guide_md = """
When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.


## Disclaimer

This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.


## Sidebar Config

Use the `Generative Config` to change LLM parameters.
Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.

"""
    st.write(guide_md)
307
+
308
+
309
+ def write_example_queries():
310
 
311
+ with st.expander("Example Queries"):
312
+ st.write(
313
+ """
314
 
 
315
  ```
316
  What are the themes around artificial intelligence?
317
  ```
 
323
  ```
324
  Create a table summarizing the major climate change ideas with columns legis_id, title, idea.
325
  ```
326
+
327
+ """
328
+ )
329
+
330
+
331
+ ##################
332
+
333
+
334
+ st.title(":classical_building: LegisQA :classical_building:")
335
 
336
 
337
  with st.sidebar:
 
340
  write_outreach_links()
341
 
342
  st.checkbox("escape markdown in answer", key="response_escape_markdown")
343
+ st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
344
 
345
  with st.expander("Generative Config"):
346
  st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
 
353
  st.slider(
354
  "Number of chunks to retrieve",
355
  min_value=1,
356
+ max_value=32,
357
+ value=8,
358
  key="n_ret_docs",
359
  )
360
  st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
361
  st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
 
362
  st.multiselect(
363
  "Congress Numbers",
364
  CONGRESS_NUMBERS,
365
  default=CONGRESS_NUMBERS,
366
  key="filter_congress_nums",
367
  )
368
+ st.multiselect(
369
+ "Sponsor Party",
370
+ SPONSOR_PARTIES,
371
+ default=SPONSOR_PARTIES,
372
+ key="filter_sponsor_parties",
373
+ )
374
 
375
  with st.expander("Prompt Config"):
376
  st.selectbox(
 
393
  openai_api_key=st.secrets["openai_api_key"],
394
  model_kwargs={"top_p": SS["top_p"], "seed": SEED},
395
  )
 
396
  vectorstore = load_pinecone_vectorstore()
397
  format_docs = DOC_FORMATTERS[SS["prompt_version"]]
398
+ vs_filter = get_vectorstore_filter()
399
 
400
+ query_tab, guide_tab = st.tabs(["query", "guide"])
 
 
401
 
402
+ with guide_tab:
403
+ write_guide()
 
 
 
 
 
 
 
 
 
404
 
405
 
406
+ with query_tab:
407
 
408
+ write_example_queries()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
+ with st.form("my_form"):
411
+ st.text_area("Enter query:", key="query")
412
+ query_submitted = st.form_submit_button("Submit")
413
 
414
+ if query_submitted:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
+ retriever = vectorstore.as_retriever(
417
+ search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
418
+ )
419
+ prompt = PromptTemplate.from_template(SS["prompt_template"])
420
+ rag_chain_from_docs = (
421
+ RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
422
+ | prompt
423
+ | llm
424
+ | StrOutputParser()
425
+ )
426
+ rag_chain_with_source = RunnableParallel(
427
+ {"context": retriever, "question": RunnablePassthrough()}
428
+ ).assign(answer=rag_chain_from_docs)
429
+
430
+ with get_openai_callback() as cb:
431
+ SS["out"] = rag_chain_with_source.invoke(SS["query"])
432
+ SS["cb"] = cb
433
+
434
+ if "out" in SS:
435
+
436
+ out_display = SS["out"]["answer"]
437
+ if SS["response_escape_markdown"]:
438
+ out_display = escape_markdown(out_display)
439
+ if SS["response_add_legis_urls"]:
440
+ out_display = replace_legis_ids_with_urls(out_display)
441
+ with st.container(border=True):
442
+ st.write("Response")
443
+ st.info(out_display)
444
+ with st.container(border=True):
445
+ st.write("API Usage")
446
+ st.warning(SS["cb"])
447
+
448
+ with st.container(border=True):
449
+ doc_grps = group_docs(SS["out"]["context"])
450
+ st.write(
451
+ "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
452
+ )
453
+ for legis_id, doc_grp in doc_grps:
454
+ write_doc_grp(legis_id, doc_grp)
455
+
456
+ # with st.expander("Debug doc format"):
457
+ # st.text_area("formatted docs", value=format_docs(SS["out"]["context"]), height=600)