gabrielaltay committed
Commit 76cbdff · 1 Parent(s): 793d0f2

more files

Files changed (8):

1. app.py +27 -231
2. doc_format_mod.py +102 -0
3. guide_mod.py +22 -0
4. retriever_tools.py +0 -79
5. sidebar_mod.py +20 -0
6. usage.py → usage_mod.py +18 -18
7. utils_mod.py +47 -0
8. vectorstore_mod.py +46 -0
app.py CHANGED
@@ -10,15 +10,17 @@ from langchain_core.documents import Document
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel
 from langchain_core.runnables import RunnablePassthrough
-from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-from langchain_community.vectorstores.utils import DistanceStrategy
 from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_together import ChatTogether
-from langchain_pinecone import PineconeVectorStore
 import streamlit as st

-import usage
+import utils_mod
+import doc_format_mod
+import guide_mod
+import sidebar_mod
+import usage_mod
+import vectorstore_mod


 st.set_page_config(layout="wide", page_title="LegisQA")
@@ -32,16 +34,7 @@ SS = st.session_state
 SEED = 292764
 CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
 SPONSOR_PARTIES = ["D", "R", "L", "I"]
-CONGRESS_GOV_TYPE_MAP = {
-    "hconres": "house-concurrent-resolution",
-    "hjres": "house-joint-resolution",
-    "hr": "house-bill",
-    "hres": "house-resolution",
-    "s": "senate-bill",
-    "sconres": "senate-concurrent-resolution",
-    "sjres": "senate-joint-resolution",
-    "sres": "senate-resolution",
-}
+
 OPENAI_CHAT_MODELS = {
     "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
     "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
@@ -68,190 +61,6 @@ PROVIDER_MODELS = {
 }


-def get_sponsor_url(bioguide_id: str) -> str:
-    return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
-
-
-def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
-    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
-    return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
-
-
-def load_bge_embeddings():
-    model_name = "BAAI/bge-small-en-v1.5"
-    model_kwargs = {"device": "cpu"}
-    encode_kwargs = {"normalize_embeddings": True}
-    emb_fn = HuggingFaceBgeEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs,
-        query_instruction="Represent this question for searching relevant passages: ",
-    )
-    return emb_fn
-
-
-def load_pinecone_vectorstore():
-    emb_fn = load_bge_embeddings()
-    vectorstore = PineconeVectorStore(
-        embedding=emb_fn,
-        text_key="text",
-        distance_strategy=DistanceStrategy.COSINE,
-        pinecone_api_key=st.secrets["pinecone_api_key"],
-        index_name=st.secrets["pinecone_index_name"],
-    )
-    return vectorstore
-
-
-def render_outreach_links():
-    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
-    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"
-    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
-    hf_url = "https://huggingface.co/hyperdemocracy"
-    pc_url = "https://www.pinecone.io/blog/serverless"
-    together_url = "https://www.together.ai/"
-    st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)")
-    st.subheader(f":world_map: Visualize [nomic atlas]({nomic_url})")
-    st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
-    st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
-    st.subheader(f":pancakes: Inference [together.ai]({together_url})")
-
-
-def render_sidebar():
-
-    with st.container(border=True):
-        render_outreach_links()
-
-
-def group_docs(docs) -> list[tuple[str, list[Document]]]:
-    doc_grps = defaultdict(list)
-
-    # create legis_id groups
-    for doc in docs:
-        doc_grps[doc.metadata["legis_id"]].append(doc)
-
-    # sort docs in each group by start index
-    for legis_id in doc_grps.keys():
-        doc_grps[legis_id] = sorted(
-            doc_grps[legis_id],
-            key=lambda x: x.metadata["start_index"],
-        )
-
-    # sort groups by number of docs
-    doc_grps = sorted(
-        tuple(doc_grps.items()),
-        key=lambda x: -len(x[1]),
-    )
-
-    return doc_grps
-
-
-def format_docs(docs: list[Document]) -> str:
-    """JSON grouped"""
-
-    doc_grps = group_docs(docs)
-    out = []
-    for legis_id, doc_grp in doc_grps:
-        dd = {
-            "legis_id": doc_grp[0].metadata["legis_id"],
-            "title": doc_grp[0].metadata["title"],
-            "introduced_date": doc_grp[0].metadata["introduced_date"],
-            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
-            "snippets": [doc.page_content for doc in doc_grp],
-        }
-        out.append(dd)
-    return json.dumps(out, indent=4)
-
-
-def escape_markdown(text: str) -> str:
-    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
-    for char in MD_SPECIAL_CHARS:
-        text = text.replace(char, "\\" + char)
-    return text
-
-
-def get_vectorstore_filter(ret_config: dict) -> dict:
-    vs_filter = {}
-    if ret_config["filter_legis_id"] != "":
-        vs_filter["legis_id"] = ret_config["filter_legis_id"]
-    if ret_config["filter_bioguide_id"] != "":
-        vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
-    vs_filter = {
-        **vs_filter,
-        "congress_num": {"$in": ret_config["filter_congress_nums"]},
-    }
-    vs_filter = {
-        **vs_filter,
-        "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
-    }
-    return vs_filter
-
-
-def render_doc_grp(legis_id: str, doc_grp: list[Document]):
-    first_doc = doc_grp[0]
-
-    congress_gov_url = get_congress_gov_url(
-        first_doc.metadata["congress_num"],
-        first_doc.metadata["legis_type"],
-        first_doc.metadata["legis_num"],
-    )
-    congress_gov_link = f"[congress.gov]({congress_gov_url})"
-
-    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
-        len(doc_grp),
-        first_doc.metadata["legis_id"],
-        first_doc.metadata["title"],
-        congress_gov_link,
-        first_doc.metadata["sponsor_full_name"],
-        first_doc.metadata["sponsor_bioguide_id"],
-        get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
-    )
-    doc_contents = [
-        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
-        for doc in doc_grp
-    ]
-    with st.expander(ref):
-        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
-
-
-def legis_id_to_link(legis_id: str) -> str:
-    congress_num, legis_type, legis_num = legis_id.split("-")
-    return get_congress_gov_url(congress_num, legis_type, legis_num)
-
-
-def legis_id_match_to_link(matchobj):
-    mstring = matchobj.string[matchobj.start() : matchobj.end()]
-    url = legis_id_to_link(mstring)
-    link = f"[{mstring}]({url})"
-    return link
-
-
-def replace_legis_ids_with_urls(text):
-    pattern = "11[345678]-[a-z]+-\d{1,5}"
-    rtext = re.sub(pattern, legis_id_match_to_link, text)
-    return rtext
-
-
-def render_guide():
-
-    st.write(
-        """
-When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
-
-
-## Disclaimer
-
-This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
-
-
-## Config
-
-Use the `Generative Config` to change LLM parameters.
-Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
-
-"""
-    )
-
-
 def render_example_queries():

     with st.expander("Example Queries"):
@@ -413,7 +222,7 @@ def get_llm(gen_config: dict):


 def create_rag_chain(llm, retriever):
-    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
+    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.

 ---

@@ -438,7 +247,7 @@ Query: {query}"""
             "query": RunnablePassthrough(),
         }
     )
-        .assign(context=lambda x: format_docs(x["docs"]))
+        .assign(context=lambda x: doc_format_mod.format_docs(x["docs"]))
         .assign(aimessage=prompt | llm)
     )

@@ -446,9 +255,9 @@ Query: {query}"""


 def process_query(gen_config: dict, ret_config: dict, query: str):
-    vectorstore = load_pinecone_vectorstore()
+    vectorstore = vectorstore_mod.load_pinecone_vectorstore()
     llm = get_llm(gen_config)
-    vs_filter = get_vectorstore_filter(ret_config)
+    vs_filter = vectorstore_mod.get_vectorstore_filter(ret_config)
     retriever = vectorstore.as_retriever(
         search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
     )
@@ -457,44 +266,31 @@ def process_query(gen_config: dict, ret_config: dict, query: str):
     return response


-def display_retrieved_chunks(docs: list[Document], tag: str|None=None):
-    with st.container(border=True):
-        doc_grps = group_docs(docs)
-        if tag is None:
-            st.write(
-                "Retrieved Chunks\n\nleft click to expand, right click to follow links"
-            )
-        else:
-            st.write(
-                f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
-            )
-        for legis_id, doc_grp in doc_grps:
-            render_doc_grp(legis_id, doc_grp)
-
-
-def display_response(
-    response,
+def render_response(
+    response: dict,
     model_info: dict,
     provider: str,
     should_escape_markdown: bool,
     should_add_legis_urls: bool,
-    tag: str|None=None
+    tag: str | None = None,
 ):
-    out_display = response["aimessage"].content
+    response_text = response["aimessage"].content
     if should_escape_markdown:
-        out_display = escape_markdown(out_display)
+        response_text = utils_mod.escape_markdown(response_text)
     if should_add_legis_urls:
-        out_display = replace_legis_ids_with_urls(out_display)
+        response_text = utils_mod.replace_legis_ids_with_urls(response_text)

     with st.container(border=True):
         if tag is None:
             st.write("Response")
         else:
             st.write(f"Response ({tag})")
-        st.info(out_display)
+        st.info(response_text)

-    usage.display_api_usage(response, model_info, provider, tag=tag)
-    display_retrieved_chunks(response["docs"], tag=tag)
+    usage_mod.display_api_usage(
+        response["aimessage"].response_metadata, model_info, provider, tag=tag
+    )
+    doc_format_mod.render_retrieved_chunks(response["docs"], tag=tag)


 def render_query_rag_tab():
@@ -527,7 +323,7 @@ def render_query_rag_tab():

     if response := SS.get(rkey):
         model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]]
-        display_response(
+        render_response(
             response,
             model_info,
             gen_config["provider"],
@@ -595,13 +391,13 @@ def render_query_rag_sbs_tab():
         model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
             gen_configs[post_key_prefix]["model_name"]
         ]
-        display_response(
+        render_response(
            response,
            model_info,
            gen_configs[post_key_prefix]["provider"],
            gen_configs[post_key_prefix]["should_escape_markdown"],
            gen_configs[post_key_prefix]["should_add_legis_urls"],
-           tag = grp_names[post_key_prefix],
+           tag=grp_names[post_key_prefix],
        )


@@ -611,7 +407,7 @@ def main():
     st.header("Query Congressional Bills")

     with st.sidebar:
-        render_sidebar()
+        sidebar_mod.render_sidebar()

     query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
         [
@@ -628,7 +424,7 @@
         render_query_rag_sbs_tab()

     with guide_tab:
-        render_guide()
+        guide_mod.render_guide()


 if __name__ == "__main__":
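
The refactor leaves the LCEL chain shape built by `create_rag_chain` intact: a `RunnableParallel` that fans out to the retriever and the raw query, followed by two `.assign` steps. A minimal sketch of that data flow, using `RunnableLambda` stand-ins for the real retriever and LLM (the stand-ins and sample query are assumptions, not app code):

```python
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# stand-ins (assumptions): the app uses a Pinecone retriever and a chat model
fake_retriever = RunnableLambda(lambda query: [f"doc matching {query!r}"])
fake_llm = RunnableLambda(lambda x: f"answer grounded in: {x['context']}")

chain = (
    RunnableParallel({"docs": fake_retriever, "query": RunnablePassthrough()})
    # same shape as create_rag_chain: add "context", then "aimessage"
    .assign(context=lambda x: "\n".join(x["docs"]))
    .assign(aimessage=fake_llm)
)

out = chain.invoke("renewable energy tax credits")
# out carries "docs", "query", "context", and "aimessage"; render_response
# reads out["aimessage"] (an AIMessage in the real app) and out["docs"]
print(out["aimessage"])
```

Keeping the full dict around (rather than piping only the prompt into the LLM) is what lets `render_response` show both the answer and the retrieved chunks from a single invocation result.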
doc_format_mod.py ADDED
@@ -0,0 +1,102 @@
+from collections import defaultdict
+import json
+
+from langchain.schema import Document
+import streamlit as st
+
+import utils_mod
+
+
+def group_docs(docs) -> list[tuple[str, list[Document]]]:
+    """Group and sort docs.
+
+    docs are grouped by legis_id
+    inside a legis_id group, the docs are sorted by start_index
+    overall the legis_id groups are sorted by number of docs (desc)
+
+    doc_grps = [
+        (legis_id, start_index sorted docs),  # group with the most docs
+        (legis_id, start_index sorted docs),
+        ...
+        (legis_id, start_index sorted docs),  # group with the least docs
+    ]
+    """
+    doc_grps = defaultdict(list)
+
+    # create legis_id groups
+    for doc in docs:
+        doc_grps[doc.metadata["legis_id"]].append(doc)
+
+    # sort docs in each group by start index
+    for legis_id in doc_grps.keys():
+        doc_grps[legis_id] = sorted(
+            doc_grps[legis_id],
+            key=lambda x: x.metadata["start_index"],
+        )
+
+    # sort groups by number of docs
+    doc_grps = sorted(
+        tuple(doc_grps.items()),
+        key=lambda x: -len(x[1]),
+    )
+
+    return doc_grps
+
+
+def format_docs(docs: list[Document]) -> str:
+    """JSON grouped"""
+
+    doc_grps = group_docs(docs)
+    out = []
+    for legis_id, doc_grp in doc_grps:
+        dd = {
+            "legis_id": doc_grp[0].metadata["legis_id"],
+            "title": doc_grp[0].metadata["title"],
+            "introduced_date": doc_grp[0].metadata["introduced_date"],
+            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
+            "snippets": [doc.page_content for doc in doc_grp],
+        }
+        out.append(dd)
+    return json.dumps(out, indent=4)
+
+
+def render_doc_grp(legis_id: str, doc_grp: list[Document]):
+    first_doc = doc_grp[0]
+
+    congress_gov_url = utils_mod.get_congress_gov_url(
+        first_doc.metadata["congress_num"],
+        first_doc.metadata["legis_type"],
+        first_doc.metadata["legis_num"],
+    )
+    congress_gov_link = f"[congress.gov]({congress_gov_url})"
+
+    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
+        len(doc_grp),
+        first_doc.metadata["legis_id"],
+        first_doc.metadata["title"],
+        congress_gov_link,
+        first_doc.metadata["sponsor_full_name"],
+        first_doc.metadata["sponsor_bioguide_id"],
+        utils_mod.get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
+    )
+    doc_contents = [
+        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
+        for doc in doc_grp
+    ]
+    with st.expander(ref):
+        st.write(utils_mod.escape_markdown("\n\n...\n\n".join(doc_contents)))
+
+
+def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
+    with st.container(border=True):
+        doc_grps = group_docs(docs)
+        if tag is None:
+            st.write(
+                "Retrieved Chunks\n\nleft click to expand, right click to follow links"
+            )
+        else:
+            st.write(
+                f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
+            )
+        for legis_id, doc_grp in doc_grps:
+            render_doc_grp(legis_id, doc_grp)
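
`group_docs` promises a specific ordering: the largest `legis_id` group first, and chunks within a group sorted by `start_index`. A small check of that behavior with toy `Document` metadata (ids and indexes are assumptions):

```python
from langchain.schema import Document
from doc_format_mod import group_docs

# toy chunks (metadata values are assumptions)
docs = [
    Document(page_content="b", metadata={"legis_id": "118-hr-1", "start_index": 500}),
    Document(page_content="a", metadata={"legis_id": "118-hr-1", "start_index": 0}),
    Document(page_content="c", metadata={"legis_id": "117-s-42", "start_index": 100}),
]

doc_grps = group_docs(docs)
assert doc_grps[0][0] == "118-hr-1"  # biggest group first
assert [d.metadata["start_index"] for d in doc_grps[0][1]] == [0, 500]  # chunk order
```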
guide_mod.py ADDED
@@ -0,0 +1,22 @@
+import streamlit as st
+
+
+def render_guide():
+
+    st.write(
+        """
+When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
+
+
+## Disclaimer
+
+This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
+
+
+## Config
+
+Use the `Generative Config` to change LLM parameters.
+Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
+
+"""
+    )
retriever_tools.py DELETED
@@ -1,79 +0,0 @@
-"""
-modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
-"""
-
-from functools import partial
-from typing import Callable
-from typing import Iterable
-from typing import Optional
-
-from langchain.schema import Document
-from langchain.tools import Tool
-from langchain_core.callbacks.manager import Callbacks
-from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.pydantic_v1 import Field
-from langchain_core.retrievers import BaseRetriever
-
-
-class RetrieverInput(BaseModel):
-    """Input to the retriever."""
-    query: str = Field(description="query to look up in retriever")
-
-
-def _get_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    format_docs: Callable[[Iterable[Document]], str],
-    callbacks: Callbacks = None,
-) -> str:
-    docs = retriever.get_relevant_documents(query, callbacks=callbacks)
-    return format_docs(docs)
-
-
-async def _aget_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    format_docs: Callable[[Iterable[Document]], str],
-    callbacks: Callbacks = None,
-) -> str:
-    docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
-    return format_docs(docs)
-
-
-def get_retriever_tool(
-    retriever: BaseRetriever,
-    name: str,
-    description: str,
-    format_docs: Callable[[Iterable[Document]], str],
-) -> Tool:
-
-    """Create a tool to do retrieval of documents.
-
-    Args:
-        retriever: The retriever to use for the retrieval
-        name: The name for the tool. This will be passed to the language model,
-            so should be unique and somewhat descriptive.
-        description: The description for the tool. This will be passed to the language
-            model, so should be descriptive.
-        format_docs: A function to turn an iterable of docs into a string.
-
-    Returns:
-        Tool class to pass to an agent
-    """
-    func = partial(
-        _get_relevant_documents,
-        retriever=retriever,
-        format_docs=format_docs,
-    )
-    afunc = partial(
-        _aget_relevant_documents,
-        retriever=retriever,
-        format_docs=format_docs,
-    )
-    return Tool(
-        name=name,
-        description=description,
-        func=func,
-        coroutine=afunc,
-        args_schema=RetrieverInput,
-    )
sidebar_mod.py ADDED
@@ -0,0 +1,20 @@
+import streamlit as st
+
+
+def render_outreach_links():
+    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
+    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"
+    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
+    hf_url = "https://huggingface.co/hyperdemocracy"
+    pc_url = "https://www.pinecone.io/blog/serverless"
+    together_url = "https://www.together.ai/"
+    st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)")
+    st.subheader(f":world_map: Visualize [nomic atlas]({nomic_url})")
+    st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
+    st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
+    st.subheader(f":pancakes: Inference [together.ai]({together_url})")
+
+
+def render_sidebar():
+    with st.container(border=True):
+        render_outreach_links()
usage.py → usage_mod.py RENAMED
@@ -1,9 +1,9 @@
 import streamlit as st


-def get_openai_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
+def get_openai_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["token_usage"]["prompt_tokens"]
+    output_tokens = response_metadata["token_usage"]["completion_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
@@ -15,9 +15,9 @@ def get_openai_token_usage(metadata: dict, model_info: dict):
     }


-def get_anthropic_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["usage"]["input_tokens"]
-    output_tokens = metadata["usage"]["output_tokens"]
+def get_anthropic_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["usage"]["input_tokens"]
+    output_tokens = response_metadata["usage"]["output_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
@@ -29,9 +29,9 @@ def get_anthropic_token_usage(metadata: dict, model_info: dict):
     }


-def get_together_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
+def get_together_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["token_usage"]["prompt_tokens"]
+    output_tokens = response_metadata["token_usage"]["completion_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
@@ -43,27 +43,27 @@ def get_together_token_usage(metadata: dict, model_info: dict):
     }


-def get_token_usage(metadata: dict, model_info: dict, provider: str):
+def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
     match provider:
         case "OpenAI":
-            return get_openai_token_usage(metadata, model_info)
+            return get_openai_token_usage(response_metadata, model_info)
         case "Anthropic":
-            return get_anthropic_token_usage(metadata, model_info)
+            return get_anthropic_token_usage(response_metadata, model_info)
         case "Together":
-            return get_together_token_usage(metadata, model_info)
+            return get_together_token_usage(response_metadata, model_info)
         case _:
             raise ValueError()


-def display_api_usage(response, model_info, provider: str, tag: str|None=None):
+def display_api_usage(
+    response_metadata: dict, model_info: dict, provider: str, tag: str | None = None
+):
     with st.container(border=True):
         if tag is None:
             st.write("API Usage")
         else:
             st.write(f"API Usage ({tag})")
-        token_usage = get_token_usage(
-            response["aimessage"].response_metadata, model_info, provider
-        )
+        token_usage = get_token_usage(response_metadata, model_info, provider)
         col1, col2, col3 = st.columns(3)
         with col1:
             st.metric("Input Tokens", token_usage["input_tokens"])
@@ -72,4 +72,4 @@ def display_api_usage(response, model_info, provider: str, tag: str|None=None):
     with col3:
         st.metric("Cost", f"${token_usage['cost']:.4f}")
     with st.expander("Response Metadata"):
-        st.warning(response["aimessage"].response_metadata)
+        st.warning(response_metadata)
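
The `pmi`/`pmo` rates in `model_info` are dollars per million input/output tokens, hence the `1e-6` scaling in each `get_*_token_usage` helper. A worked example using the gpt-4o-mini rates defined in `app.py` (the token counts are made up):

```python
# gpt-4o-mini rates from OPENAI_CHAT_MODELS in app.py
model_info = {"cost": {"pmi": 0.15, "pmo": 0.60}}  # $ per million input / output tokens
input_tokens, output_tokens = 12_000, 800  # assumed counts

cost = (
    input_tokens * 1e-6 * model_info["cost"]["pmi"]
    + output_tokens * 1e-6 * model_info["cost"]["pmo"]
)
print(f"${cost:.4f}")  # $0.0023 = $0.0018 input + $0.0005 output
```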
utils_mod.py ADDED
@@ -0,0 +1,47 @@
+import re
+
+
+CONGRESS_GOV_TYPE_MAP = {
+    "hconres": "house-concurrent-resolution",
+    "hjres": "house-joint-resolution",
+    "hr": "house-bill",
+    "hres": "house-resolution",
+    "s": "senate-bill",
+    "sconres": "senate-concurrent-resolution",
+    "sjres": "senate-joint-resolution",
+    "sres": "senate-resolution",
+}
+
+
+def escape_markdown(text: str) -> str:
+    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
+    for char in MD_SPECIAL_CHARS:
+        text = text.replace(char, "\\" + char)
+    return text
+
+
+def get_sponsor_url(bioguide_id: str) -> str:
+    return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
+
+
+def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
+    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
+    return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
+
+
+def legis_id_to_link(legis_id: str) -> str:
+    congress_num, legis_type, legis_num = legis_id.split("-")
+    return get_congress_gov_url(congress_num, legis_type, legis_num)
+
+
+def legis_id_match_to_link(matchobj):
+    mstring = matchobj.string[matchobj.start() : matchobj.end()]
+    url = legis_id_to_link(mstring)
+    link = f"[{mstring}]({url})"
+    return link
+
+
+def replace_legis_ids_with_urls(text: str) -> str:
+    pattern = "11[345678]-[a-z]+-\d{1,5}"
+    rtext = re.sub(pattern, legis_id_match_to_link, text)
+    return rtext
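
Note that `replace_legis_ids_with_urls` only linkifies ids for the 113th-118th congresses. A minimal sketch of the round trip (the sample ids are assumptions):

```python
from utils_mod import replace_legis_ids_with_urls

text = "See 118-hr-1234 and 119-hr-1 for details."
print(replace_legis_ids_with_urls(text))
# See [118-hr-1234](https://www.congress.gov/bill/118th-congress/house-bill/1234)
# and 119-hr-1 for details.  (119-hr-1 falls outside 11[345678] and is left alone)
```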
vectorstore_mod.py ADDED
@@ -0,0 +1,46 @@
+import streamlit as st
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_pinecone import PineconeVectorStore
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+
+def load_bge_embeddings():
+    model_name = "BAAI/bge-small-en-v1.5"
+    model_kwargs = {"device": "cpu"}
+    encode_kwargs = {"normalize_embeddings": True}
+    emb_fn = HuggingFaceBgeEmbeddings(
+        model_name=model_name,
+        model_kwargs=model_kwargs,
+        encode_kwargs=encode_kwargs,
+        query_instruction="Represent this question for searching relevant passages: ",
+    )
+    return emb_fn
+
+
+def load_pinecone_vectorstore():
+    emb_fn = load_bge_embeddings()
+    vectorstore = PineconeVectorStore(
+        embedding=emb_fn,
+        text_key="text",
+        distance_strategy=DistanceStrategy.COSINE,
+        pinecone_api_key=st.secrets["pinecone_api_key"],
+        index_name=st.secrets["pinecone_index_name"],
+    )
+    return vectorstore
+
+
+def get_vectorstore_filter(ret_config: dict) -> dict:
+    vs_filter = {}
+    if ret_config["filter_legis_id"] != "":
+        vs_filter["legis_id"] = ret_config["filter_legis_id"]
+    if ret_config["filter_bioguide_id"] != "":
+        vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
+    vs_filter = {
+        **vs_filter,
+        "congress_num": {"$in": ret_config["filter_congress_nums"]},
+    }
+    vs_filter = {
+        **vs_filter,
+        "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
+    }
+    return vs_filter
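
`get_vectorstore_filter` always applies the congress and party `$in` clauses (Pinecone metadata-filter syntax) and adds the id filters only when they are non-empty strings. A sketch of the resulting filter dict for a hypothetical `ret_config` (all values are assumptions):

```python
from vectorstore_mod import get_vectorstore_filter

ret_config = {
    "filter_legis_id": "",            # empty -> omitted from the filter
    "filter_bioguide_id": "B001135",  # non-empty -> exact-match clause
    "filter_congress_nums": [117, 118],
    "filter_sponsor_parties": ["D", "R"],
}

print(get_vectorstore_filter(ret_config))
# {'sponsor_bioguide_id': 'B001135',
#  'congress_num': {'$in': [117, 118]},
#  'sponsor_party': {'$in': ['D', 'R']}}
```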