Spaces:
Running
Running
gabrielaltay
commited on
Commit
•
42554ac
1
Parent(s):
2b72dfd
upates
Browse files
app.py
CHANGED
@@ -6,36 +6,22 @@ import json
|
|
6 |
import os
|
7 |
import re
|
8 |
|
9 |
-
from langchain.tools.retriever import create_retriever_tool
|
10 |
-
from langchain.agents import AgentExecutor
|
11 |
-
from langchain.agents import create_openai_tools_agent
|
12 |
-
from langchain.agents.format_scratchpad.openai_tools import (
|
13 |
-
format_to_openai_tool_messages,
|
14 |
-
)
|
15 |
-
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
16 |
from langchain_core.documents import Document
|
17 |
-
from langchain_core.prompts import PromptTemplate
|
18 |
from langchain_core.prompts import ChatPromptTemplate
|
19 |
-
from langchain_core.prompts import MessagesPlaceholder
|
20 |
-
from langchain_core.messages import AIMessage
|
21 |
-
from langchain_core.messages import HumanMessage
|
22 |
from langchain_core.runnables import RunnableParallel
|
23 |
from langchain_core.runnables import RunnablePassthrough
|
24 |
-
from langchain_core.output_parsers import StrOutputParser
|
25 |
-
from langchain_community.callbacks import get_openai_callback
|
26 |
-
from langchain_community.callbacks import StreamlitCallbackHandler
|
27 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
28 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
29 |
from langchain_openai import ChatOpenAI
|
30 |
from langchain_anthropic import ChatAnthropic
|
31 |
from langchain_together import ChatTogether
|
32 |
from langchain_pinecone import PineconeVectorStore
|
33 |
-
from pinecone import Pinecone
|
34 |
import streamlit as st
|
35 |
|
|
|
36 |
|
37 |
-
st.set_page_config(layout="wide", page_title="LegisQA")
|
38 |
|
|
|
39 |
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
|
40 |
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
41 |
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
|
@@ -70,7 +56,9 @@ TOGETHER_CHAT_MODELS = {
|
|
70 |
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
|
71 |
"cost": {"pmi": 0.88, "pmo": 0.88}
|
72 |
},
|
73 |
-
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
|
|
|
|
|
74 |
}
|
75 |
|
76 |
PROVIDER_MODELS = {
|
@@ -128,6 +116,12 @@ def render_outreach_links():
|
|
128 |
st.subheader(f":pancakes: Inference [together.ai]({together_url})")
|
129 |
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
132 |
doc_grps = defaultdict(list)
|
133 |
|
@@ -151,7 +145,7 @@ def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
|
151 |
return doc_grps
|
152 |
|
153 |
|
154 |
-
def format_docs(docs):
|
155 |
"""JSON grouped"""
|
156 |
|
157 |
doc_grps = group_docs(docs)
|
@@ -168,26 +162,26 @@ def format_docs(docs):
|
|
168 |
return json.dumps(out, indent=4)
|
169 |
|
170 |
|
171 |
-
def escape_markdown(text):
|
172 |
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
173 |
for char in MD_SPECIAL_CHARS:
|
174 |
text = text.replace(char, "\\" + char)
|
175 |
return text
|
176 |
|
177 |
|
178 |
-
def get_vectorstore_filter(
|
179 |
vs_filter = {}
|
180 |
-
if
|
181 |
-
vs_filter["legis_id"] =
|
182 |
-
if
|
183 |
-
vs_filter["sponsor_bioguide_id"] =
|
184 |
vs_filter = {
|
185 |
**vs_filter,
|
186 |
-
"congress_num": {"$in":
|
187 |
}
|
188 |
vs_filter = {
|
189 |
**vs_filter,
|
190 |
-
"sponsor_party": {"$in":
|
191 |
}
|
192 |
return vs_filter
|
193 |
|
@@ -288,163 +282,137 @@ Suggest reforms that would benefit the Medicaid program.
|
|
288 |
)
|
289 |
|
290 |
|
291 |
-
def
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
294 |
)
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
key
|
|
|
|
|
299 |
)
|
300 |
-
|
301 |
-
|
|
|
|
|
302 |
min_value=0.0,
|
303 |
max_value=2.0,
|
304 |
-
value=0.
|
305 |
-
key=f"{key_prefix}|
|
306 |
)
|
307 |
-
|
308 |
-
|
|
|
|
|
309 |
min_value=1024,
|
310 |
max_value=2048,
|
311 |
-
key=f"{key_prefix}|
|
312 |
)
|
313 |
-
|
314 |
-
|
|
|
|
|
315 |
)
|
316 |
-
|
317 |
-
|
|
|
|
|
|
|
|
|
318 |
)
|
319 |
-
|
320 |
-
|
|
|
|
|
321 |
value=True,
|
322 |
-
key=f"{key_prefix}|
|
323 |
)
|
324 |
|
|
|
|
|
|
|
|
|
|
|
325 |
|
326 |
-
|
327 |
-
st.slider(
|
328 |
"Number of chunks to retrieve",
|
329 |
min_value=1,
|
330 |
max_value=32,
|
331 |
value=8,
|
332 |
-
key=f"{key_prefix}|
|
333 |
)
|
334 |
-
|
335 |
-
|
336 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
"Congress Numbers",
|
338 |
CONGRESS_NUMBERS,
|
339 |
default=CONGRESS_NUMBERS,
|
340 |
-
key=f"{key_prefix}|
|
341 |
)
|
342 |
-
|
|
|
|
|
343 |
"Sponsor Party",
|
344 |
SPONSOR_PARTIES,
|
345 |
default=SPONSOR_PARTIES,
|
346 |
-
key=f"{key_prefix}|
|
347 |
)
|
348 |
|
|
|
349 |
|
350 |
-
def get_llm(key_prefix: str):
|
351 |
-
|
352 |
-
if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
|
353 |
-
llm = ChatOpenAI(
|
354 |
-
model=SS[f"{key_prefix}|model_name"],
|
355 |
-
temperature=SS[f"{key_prefix}|temperature"],
|
356 |
-
api_key=st.secrets["openai_api_key"],
|
357 |
-
top_p=SS[f"{key_prefix}|top_p"],
|
358 |
-
seed=SEED,
|
359 |
-
max_tokens=SS[f"{key_prefix}|max_output_tokens"],
|
360 |
-
)
|
361 |
-
elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
|
362 |
-
llm = ChatAnthropic(
|
363 |
-
model_name=SS[f"{key_prefix}|model_name"],
|
364 |
-
temperature=SS[f"{key_prefix}|temperature"],
|
365 |
-
api_key=st.secrets["anthropic_api_key"],
|
366 |
-
top_p=SS[f"{key_prefix}|top_p"],
|
367 |
-
max_tokens_to_sample=SS[f"{key_prefix}|max_output_tokens"],
|
368 |
-
)
|
369 |
-
elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
|
370 |
-
llm = ChatTogether(
|
371 |
-
model=SS[f"{key_prefix}|model_name"],
|
372 |
-
temperature=SS[f"{key_prefix}|temperature"],
|
373 |
-
max_tokens=SS[f"{key_prefix}|max_output_tokens"],
|
374 |
-
top_p=SS[f"{key_prefix}|top_p"],
|
375 |
-
seed=SEED,
|
376 |
-
api_key=st.secrets["together_api_key"],
|
377 |
-
)
|
378 |
-
else:
|
379 |
-
raise ValueError()
|
380 |
|
381 |
-
|
382 |
|
|
|
383 |
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
return get_together_token_usage(metadata, model_info)
|
394 |
-
else:
|
395 |
-
raise ValueError()
|
396 |
-
|
397 |
-
|
398 |
-
def get_openai_token_usage(metadata: dict, model_info: dict):
|
399 |
-
input_tokens = metadata["token_usage"]["prompt_tokens"]
|
400 |
-
output_tokens = metadata["token_usage"]["completion_tokens"]
|
401 |
-
cost = (
|
402 |
-
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
403 |
-
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
404 |
-
)
|
405 |
-
return {
|
406 |
-
"input_tokens": input_tokens,
|
407 |
-
"output_tokens": output_tokens,
|
408 |
-
"cost": cost,
|
409 |
-
}
|
410 |
-
|
411 |
-
|
412 |
-
def get_anthropic_token_usage(metadata: dict, model_info: dict):
|
413 |
-
input_tokens = metadata["usage"]["input_tokens"]
|
414 |
-
output_tokens = metadata["usage"]["output_tokens"]
|
415 |
-
cost = (
|
416 |
-
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
417 |
-
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
418 |
-
)
|
419 |
-
return {
|
420 |
-
"input_tokens": input_tokens,
|
421 |
-
"output_tokens": output_tokens,
|
422 |
-
"cost": cost,
|
423 |
-
}
|
424 |
-
|
425 |
-
|
426 |
-
def get_together_token_usage(metadata: dict, model_info: dict):
|
427 |
-
input_tokens = metadata["token_usage"]["prompt_tokens"]
|
428 |
-
output_tokens = metadata["token_usage"]["completion_tokens"]
|
429 |
-
cost = (
|
430 |
-
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
431 |
-
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
432 |
-
)
|
433 |
-
return {
|
434 |
-
"input_tokens": input_tokens,
|
435 |
-
"output_tokens": output_tokens,
|
436 |
-
"cost": cost,
|
437 |
-
}
|
438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
|
442 |
-
|
443 |
-
|
444 |
|
|
|
445 |
|
446 |
-
def render_query_rag_tab():
|
447 |
|
|
|
448 |
QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
|
449 |
|
450 |
---
|
@@ -463,219 +431,191 @@ Query: {query}"""
|
|
463 |
]
|
464 |
)
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
key_prefix = "query_rag"
|
467 |
render_example_queries()
|
468 |
|
469 |
with st.form(f"{key_prefix}|query_form"):
|
470 |
-
st.text_area(
|
471 |
-
"Enter a query that can be answered with congressional legislation:"
|
472 |
-
key=f"{key_prefix}|query",
|
473 |
)
|
474 |
-
|
|
|
|
|
|
|
|
|
475 |
|
476 |
col1, col2 = st.columns(2)
|
477 |
with col1:
|
478 |
with st.expander("Generative Config"):
|
479 |
-
|
480 |
with col2:
|
481 |
with st.expander("Retrieval Config"):
|
482 |
-
|
483 |
|
|
|
484 |
if query_submitted:
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
)
|
491 |
|
492 |
-
rag_chain = (
|
493 |
-
RunnableParallel(
|
494 |
-
{
|
495 |
-
"docs": retriever, # list of docs
|
496 |
-
"query": RunnablePassthrough(), # str
|
497 |
-
}
|
498 |
-
)
|
499 |
-
.assign(context=(lambda x: format_docs(x["docs"])))
|
500 |
-
.assign(output=prompt | llm)
|
501 |
-
)
|
502 |
-
|
503 |
-
SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{key_prefix}|query"])
|
504 |
-
|
505 |
-
if f"{key_prefix}|out" in SS:
|
506 |
-
|
507 |
-
out_display = SS[f"{key_prefix}|out"]["output"].content
|
508 |
-
if SS[f"{key_prefix}|response_escape_markdown"]:
|
509 |
-
out_display = escape_markdown(out_display)
|
510 |
-
if SS[f"{key_prefix}|response_add_legis_urls"]:
|
511 |
-
out_display = replace_legis_ids_with_urls(out_display)
|
512 |
-
with st.container(border=True):
|
513 |
-
st.write("Response")
|
514 |
-
st.info(out_display)
|
515 |
-
|
516 |
-
with st.container(border=True):
|
517 |
-
st.write("API Usage")
|
518 |
-
token_usage = get_token_usage(
|
519 |
-
key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
|
520 |
-
)
|
521 |
-
col1, col2, col3 = st.columns(3)
|
522 |
-
with col1:
|
523 |
-
st.metric("Input Tokens", token_usage["input_tokens"])
|
524 |
-
with col2:
|
525 |
-
st.metric("Output Tokens", token_usage["output_tokens"])
|
526 |
-
with col3:
|
527 |
-
st.metric("Cost", f"${token_usage['cost']:.4f}")
|
528 |
-
with st.expander("Response Metadata"):
|
529 |
-
st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
|
530 |
-
|
531 |
-
with st.container(border=True):
|
532 |
-
doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
|
533 |
-
st.write(
|
534 |
-
"Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
|
535 |
-
)
|
536 |
-
for legis_id, doc_grp in doc_grps:
|
537 |
-
render_doc_grp(legis_id, doc_grp)
|
538 |
-
|
539 |
with st.expander("Debug"):
|
540 |
-
st.write(
|
541 |
|
542 |
|
543 |
def render_query_rag_sbs_tab():
|
544 |
-
|
545 |
-
QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
|
546 |
-
|
547 |
-
---
|
548 |
-
|
549 |
-
Congressional Legislation Excerpts:
|
550 |
-
|
551 |
-
{context}
|
552 |
-
|
553 |
-
---
|
554 |
-
|
555 |
-
Query: {query}"""
|
556 |
-
|
557 |
base_key_prefix = "query_rag_sbs"
|
558 |
|
559 |
-
prompt = ChatPromptTemplate.from_messages(
|
560 |
-
[
|
561 |
-
("human", QUERY_RAG_TEMPLATE),
|
562 |
-
]
|
563 |
-
)
|
564 |
-
|
565 |
with st.form(f"{base_key_prefix}|query_form"):
|
566 |
-
st.text_area(
|
567 |
-
"Enter a query that can be answered with congressional legislation:"
|
568 |
-
key=f"{base_key_prefix}|query",
|
569 |
)
|
570 |
-
|
|
|
|
|
|
|
|
|
571 |
|
572 |
grp1a, grp2a = st.columns(2)
|
573 |
|
|
|
|
|
574 |
with grp1a:
|
575 |
st.header("Group 1")
|
576 |
key_prefix = f"{base_key_prefix}|grp1"
|
577 |
with st.expander("Generative Config"):
|
578 |
-
|
579 |
with st.expander("Retrieval Config"):
|
580 |
-
|
581 |
|
582 |
with grp2a:
|
583 |
st.header("Group 2")
|
584 |
key_prefix = f"{base_key_prefix}|grp2"
|
585 |
with st.expander("Generative Config"):
|
586 |
-
|
587 |
with st.expander("Retrieval Config"):
|
588 |
-
|
589 |
|
590 |
grp1b, grp2b = st.columns(2)
|
591 |
sbs_cols = {"grp1": grp1b, "grp2": grp2b}
|
|
|
592 |
|
593 |
for post_key_prefix in ["grp1", "grp2"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
594 |
|
595 |
-
key_prefix = f"{base_key_prefix}|{post_key_prefix}"
|
596 |
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
if f"{key_prefix}|out" in SS:
|
619 |
-
with sbs_cols[post_key_prefix]:
|
620 |
-
out_display = SS[f"{key_prefix}|out"]["output"].content
|
621 |
-
if SS[f"{key_prefix}|response_escape_markdown"]:
|
622 |
-
out_display = escape_markdown(out_display)
|
623 |
-
if SS[f"{key_prefix}|response_add_legis_urls"]:
|
624 |
-
out_display = replace_legis_ids_with_urls(out_display)
|
625 |
-
with st.container(border=True):
|
626 |
-
st.write("Response")
|
627 |
-
st.info(out_display)
|
628 |
-
|
629 |
-
with st.container(border=True):
|
630 |
-
st.write("API Usage")
|
631 |
-
token_usage = get_token_usage(
|
632 |
-
key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
|
633 |
-
)
|
634 |
-
col1, col2, col3 = st.columns(3)
|
635 |
-
with col1:
|
636 |
-
st.metric("Input Tokens", token_usage["input_tokens"])
|
637 |
-
with col2:
|
638 |
-
st.metric("Output Tokens", token_usage["output_tokens"])
|
639 |
-
with col3:
|
640 |
-
st.metric("Cost", f"${token_usage['cost']:.4f}")
|
641 |
-
with st.expander("Response Metadata"):
|
642 |
-
st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
|
643 |
-
|
644 |
-
with st.container(border=True):
|
645 |
-
doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
|
646 |
-
st.write(
|
647 |
-
"Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
|
648 |
-
)
|
649 |
-
for legis_id, doc_grp in doc_grps:
|
650 |
-
render_doc_grp(legis_id, doc_grp)
|
651 |
-
|
652 |
-
|
653 |
-
##################
|
654 |
-
|
655 |
-
|
656 |
-
st.title(":classical_building: LegisQA :classical_building:")
|
657 |
-
st.header("Chat With Congressional Bills")
|
658 |
-
|
659 |
-
|
660 |
-
with st.sidebar:
|
661 |
-
render_sidebar()
|
662 |
-
|
663 |
-
|
664 |
-
vectorstore = load_pinecone_vectorstore()
|
665 |
-
|
666 |
-
query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
|
667 |
-
[
|
668 |
-
"RAG",
|
669 |
-
"RAG (side-by-side)",
|
670 |
-
"Guide",
|
671 |
-
]
|
672 |
-
)
|
673 |
|
674 |
-
with
|
675 |
-
|
676 |
|
677 |
-
with query_rag_sbs_tab:
|
678 |
-
render_query_rag_sbs_tab()
|
679 |
|
680 |
-
|
681 |
-
|
|
|
6 |
import os
|
7 |
import re
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from langchain_core.documents import Document
|
|
|
10 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
11 |
from langchain_core.runnables import RunnableParallel
|
12 |
from langchain_core.runnables import RunnablePassthrough
|
|
|
|
|
|
|
13 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
14 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
15 |
from langchain_openai import ChatOpenAI
|
16 |
from langchain_anthropic import ChatAnthropic
|
17 |
from langchain_together import ChatTogether
|
18 |
from langchain_pinecone import PineconeVectorStore
|
|
|
19 |
import streamlit as st
|
20 |
|
21 |
+
import usage
|
22 |
|
|
|
23 |
|
24 |
+
st.set_page_config(layout="wide", page_title="LegisQA")
|
25 |
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
|
26 |
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
27 |
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
|
|
|
56 |
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
|
57 |
"cost": {"pmi": 0.88, "pmo": 0.88}
|
58 |
},
|
59 |
+
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
|
60 |
+
"cost": {"pmi": 5.00, "pmo": 5.00}
|
61 |
+
},
|
62 |
}
|
63 |
|
64 |
PROVIDER_MODELS = {
|
|
|
116 |
st.subheader(f":pancakes: Inference [together.ai]({together_url})")
|
117 |
|
118 |
|
119 |
+
def render_sidebar():
|
120 |
+
|
121 |
+
with st.container(border=True):
|
122 |
+
render_outreach_links()
|
123 |
+
|
124 |
+
|
125 |
def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
126 |
doc_grps = defaultdict(list)
|
127 |
|
|
|
145 |
return doc_grps
|
146 |
|
147 |
|
148 |
+
def format_docs(docs: list[Document]) -> str:
|
149 |
"""JSON grouped"""
|
150 |
|
151 |
doc_grps = group_docs(docs)
|
|
|
162 |
return json.dumps(out, indent=4)
|
163 |
|
164 |
|
165 |
+
def escape_markdown(text: str) -> str:
|
166 |
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
167 |
for char in MD_SPECIAL_CHARS:
|
168 |
text = text.replace(char, "\\" + char)
|
169 |
return text
|
170 |
|
171 |
|
172 |
+
def get_vectorstore_filter(ret_config: dict) -> dict:
|
173 |
vs_filter = {}
|
174 |
+
if ret_config["filter_legis_id"] != "":
|
175 |
+
vs_filter["legis_id"] = ret_config["filter_legis_id"]
|
176 |
+
if ret_config["filter_bioguide_id"] != "":
|
177 |
+
vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
|
178 |
vs_filter = {
|
179 |
**vs_filter,
|
180 |
+
"congress_num": {"$in": ret_config["filter_congress_nums"]},
|
181 |
}
|
182 |
vs_filter = {
|
183 |
**vs_filter,
|
184 |
+
"sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
|
185 |
}
|
186 |
return vs_filter
|
187 |
|
|
|
282 |
)
|
283 |
|
284 |
|
285 |
+
def get_generative_config(key_prefix: str) -> dict:
|
286 |
+
output = {}
|
287 |
+
|
288 |
+
key = "provider"
|
289 |
+
output[key] = st.selectbox(
|
290 |
+
label=key, options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|{key}"
|
291 |
)
|
292 |
+
|
293 |
+
key = "model_name"
|
294 |
+
output[key] = st.selectbox(
|
295 |
+
label=key,
|
296 |
+
options=PROVIDER_MODELS[output["provider"]],
|
297 |
+
key=f"{key_prefix}|{key}",
|
298 |
)
|
299 |
+
|
300 |
+
key = "temperature"
|
301 |
+
output[key] = st.slider(
|
302 |
+
key,
|
303 |
min_value=0.0,
|
304 |
max_value=2.0,
|
305 |
+
value=0.0,
|
306 |
+
key=f"{key_prefix}|{key}",
|
307 |
)
|
308 |
+
|
309 |
+
key = "max_output_tokens"
|
310 |
+
output[key] = st.slider(
|
311 |
+
key,
|
312 |
min_value=1024,
|
313 |
max_value=2048,
|
314 |
+
key=f"{key_prefix}|{key}",
|
315 |
)
|
316 |
+
|
317 |
+
key = "top_p"
|
318 |
+
output[key] = st.slider(
|
319 |
+
key, min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|{key}"
|
320 |
)
|
321 |
+
|
322 |
+
key = "should_escape_markdown"
|
323 |
+
output[key] = st.checkbox(
|
324 |
+
key,
|
325 |
+
value=False,
|
326 |
+
key=f"{key_prefix}|{key}",
|
327 |
)
|
328 |
+
|
329 |
+
key = "should_add_legis_urls"
|
330 |
+
output[key] = st.checkbox(
|
331 |
+
key,
|
332 |
value=True,
|
333 |
+
key=f"{key_prefix}|{key}",
|
334 |
)
|
335 |
|
336 |
+
return output
|
337 |
+
|
338 |
+
|
339 |
+
def get_retrieval_config(key_prefix: str) -> dict:
|
340 |
+
output = {}
|
341 |
|
342 |
+
key = "n_ret_docs"
|
343 |
+
output[key] = st.slider(
|
344 |
"Number of chunks to retrieve",
|
345 |
min_value=1,
|
346 |
max_value=32,
|
347 |
value=8,
|
348 |
+
key=f"{key_prefix}|{key}",
|
349 |
)
|
350 |
+
|
351 |
+
key = "filter_legis_id"
|
352 |
+
output[key] = st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|{key}")
|
353 |
+
|
354 |
+
key = "filter_bioguide_id"
|
355 |
+
output[key] = st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|{key}")
|
356 |
+
|
357 |
+
key = "filter_congress_nums"
|
358 |
+
output[key] = st.multiselect(
|
359 |
"Congress Numbers",
|
360 |
CONGRESS_NUMBERS,
|
361 |
default=CONGRESS_NUMBERS,
|
362 |
+
key=f"{key_prefix}|{key}",
|
363 |
)
|
364 |
+
|
365 |
+
key = "filter_sponsor_parties"
|
366 |
+
output[key] = st.multiselect(
|
367 |
"Sponsor Party",
|
368 |
SPONSOR_PARTIES,
|
369 |
default=SPONSOR_PARTIES,
|
370 |
+
key=f"{key_prefix}|{key}",
|
371 |
)
|
372 |
|
373 |
+
return output
|
374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
+
def get_llm(gen_config: dict):
|
377 |
|
378 |
+
match gen_config["provider"]:
|
379 |
|
380 |
+
case "OpenAI":
|
381 |
+
llm = ChatOpenAI(
|
382 |
+
model=gen_config["model_name"],
|
383 |
+
temperature=gen_config["temperature"],
|
384 |
+
api_key=st.secrets["openai_api_key"],
|
385 |
+
top_p=gen_config["top_p"],
|
386 |
+
seed=SEED,
|
387 |
+
max_tokens=gen_config["max_output_tokens"],
|
388 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
+
case "Anthropic":
|
391 |
+
llm = ChatAnthropic(
|
392 |
+
model_name=gen_config["model_name"],
|
393 |
+
temperature=gen_config["temperature"],
|
394 |
+
api_key=st.secrets["anthropic_api_key"],
|
395 |
+
top_p=gen_config["top_p"],
|
396 |
+
max_tokens_to_sample=gen_config["max_output_tokens"],
|
397 |
+
)
|
398 |
|
399 |
+
case "Together":
|
400 |
+
llm = ChatTogether(
|
401 |
+
model=gen_config["model_name"],
|
402 |
+
temperature=gen_config["temperature"],
|
403 |
+
max_tokens=gen_config["max_output_tokens"],
|
404 |
+
top_p=gen_config["top_p"],
|
405 |
+
seed=SEED,
|
406 |
+
api_key=st.secrets["together_api_key"],
|
407 |
+
)
|
408 |
|
409 |
+
case _:
|
410 |
+
raise ValueError()
|
411 |
|
412 |
+
return llm
|
413 |
|
|
|
414 |
|
415 |
+
def create_rag_chain(llm, retriever):
|
416 |
QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
|
417 |
|
418 |
---
|
|
|
431 |
]
|
432 |
)
|
433 |
|
434 |
+
rag_chain = (
|
435 |
+
RunnableParallel(
|
436 |
+
{
|
437 |
+
"docs": retriever,
|
438 |
+
"query": RunnablePassthrough(),
|
439 |
+
}
|
440 |
+
)
|
441 |
+
.assign(context=lambda x: format_docs(x["docs"]))
|
442 |
+
.assign(aimessage=prompt | llm)
|
443 |
+
)
|
444 |
+
|
445 |
+
return rag_chain
|
446 |
+
|
447 |
+
|
448 |
+
def process_query(gen_config: dict, ret_config: dict, query: str):
|
449 |
+
vectorstore = load_pinecone_vectorstore()
|
450 |
+
llm = get_llm(gen_config)
|
451 |
+
vs_filter = get_vectorstore_filter(ret_config)
|
452 |
+
retriever = vectorstore.as_retriever(
|
453 |
+
search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
|
454 |
+
)
|
455 |
+
rag_chain = create_rag_chain(llm, retriever)
|
456 |
+
response = rag_chain.invoke(query)
|
457 |
+
return response
|
458 |
+
|
459 |
+
|
460 |
+
def display_retrieved_chunks(response):
|
461 |
+
with st.container(border=True):
|
462 |
+
doc_grps = group_docs(response["docs"])
|
463 |
+
st.write(
|
464 |
+
"Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
|
465 |
+
)
|
466 |
+
for legis_id, doc_grp in doc_grps:
|
467 |
+
render_doc_grp(legis_id, doc_grp)
|
468 |
+
|
469 |
+
|
470 |
+
def display_response(
|
471 |
+
response, model_info, provider, should_escape_markdown, should_add_legis_urls
|
472 |
+
):
|
473 |
+
out_display = response["aimessage"].content
|
474 |
+
if should_escape_markdown:
|
475 |
+
out_display = escape_markdown(out_display)
|
476 |
+
if should_add_legis_urls:
|
477 |
+
out_display = replace_legis_ids_with_urls(out_display)
|
478 |
+
|
479 |
+
with st.container(border=True):
|
480 |
+
st.write("Response")
|
481 |
+
st.info(out_display)
|
482 |
+
|
483 |
+
usage.display_api_usage(response, model_info, provider)
|
484 |
+
display_retrieved_chunks(response)
|
485 |
+
|
486 |
+
|
487 |
+
def render_query_rag_tab():
|
488 |
key_prefix = "query_rag"
|
489 |
render_example_queries()
|
490 |
|
491 |
with st.form(f"{key_prefix}|query_form"):
|
492 |
+
query = st.text_area(
|
493 |
+
"Enter a query that can be answered with congressional legislation:"
|
|
|
494 |
)
|
495 |
+
cols = st.columns(2)
|
496 |
+
with cols[0]:
|
497 |
+
query_submitted = st.form_submit_button("Submit")
|
498 |
+
with cols[1]:
|
499 |
+
status_placeholder = st.empty()
|
500 |
|
501 |
col1, col2 = st.columns(2)
|
502 |
with col1:
|
503 |
with st.expander("Generative Config"):
|
504 |
+
gen_config = get_generative_config(key_prefix)
|
505 |
with col2:
|
506 |
with st.expander("Retrieval Config"):
|
507 |
+
ret_config = get_retrieval_config(key_prefix)
|
508 |
|
509 |
+
rkey = f"{key_prefix}|response"
|
510 |
if query_submitted:
|
511 |
+
with status_placeholder:
|
512 |
+
with st.spinner("generating response"):
|
513 |
+
SS[rkey] = process_query(gen_config, ret_config, query)
|
514 |
+
|
515 |
+
if response := SS.get(rkey):
|
516 |
+
model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]]
|
517 |
+
display_response(
|
518 |
+
response,
|
519 |
+
model_info,
|
520 |
+
gen_config["provider"],
|
521 |
+
gen_config["should_escape_markdown"],
|
522 |
+
gen_config["should_add_legis_urls"],
|
523 |
)
|
524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
with st.expander("Debug"):
|
526 |
+
st.write(response)
|
527 |
|
528 |
|
529 |
def render_query_rag_sbs_tab():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
base_key_prefix = "query_rag_sbs"
|
531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
with st.form(f"{base_key_prefix}|query_form"):
|
533 |
+
query = st.text_area(
|
534 |
+
"Enter a query that can be answered with congressional legislation:"
|
|
|
535 |
)
|
536 |
+
cols = st.columns(2)
|
537 |
+
with cols[0]:
|
538 |
+
query_submitted = st.form_submit_button("Submit")
|
539 |
+
with cols[1]:
|
540 |
+
status_placeholder = st.empty()
|
541 |
|
542 |
grp1a, grp2a = st.columns(2)
|
543 |
|
544 |
+
gen_configs = {}
|
545 |
+
ret_configs = {}
|
546 |
with grp1a:
|
547 |
st.header("Group 1")
|
548 |
key_prefix = f"{base_key_prefix}|grp1"
|
549 |
with st.expander("Generative Config"):
|
550 |
+
gen_configs["grp1"] = get_generative_config(key_prefix)
|
551 |
with st.expander("Retrieval Config"):
|
552 |
+
ret_configs["grp1"] = get_retrieval_config(key_prefix)
|
553 |
|
554 |
with grp2a:
|
555 |
st.header("Group 2")
|
556 |
key_prefix = f"{base_key_prefix}|grp2"
|
557 |
with st.expander("Generative Config"):
|
558 |
+
gen_configs["grp2"] = get_generative_config(key_prefix)
|
559 |
with st.expander("Retrieval Config"):
|
560 |
+
ret_configs["grp2"] = get_retrieval_config(key_prefix)
|
561 |
|
562 |
grp1b, grp2b = st.columns(2)
|
563 |
sbs_cols = {"grp1": grp1b, "grp2": grp2b}
|
564 |
+
grp_names = {"grp1": "Group 1", "grp2": "Group 2"}
|
565 |
|
566 |
for post_key_prefix in ["grp1", "grp2"]:
|
567 |
+
with sbs_cols[post_key_prefix]:
|
568 |
+
key_prefix = f"{base_key_prefix}|{post_key_prefix}"
|
569 |
+
rkey = f"{key_prefix}|response"
|
570 |
+
if query_submitted:
|
571 |
+
with status_placeholder:
|
572 |
+
with st.spinner(
|
573 |
+
"generating response for {}".format(grp_names[post_key_prefix])
|
574 |
+
):
|
575 |
+
SS[rkey] = process_query(
|
576 |
+
gen_configs[post_key_prefix],
|
577 |
+
ret_configs[post_key_prefix],
|
578 |
+
query,
|
579 |
+
)
|
580 |
+
|
581 |
+
if response := SS.get(rkey):
|
582 |
+
model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
|
583 |
+
gen_configs[post_key_prefix]["model_name"]
|
584 |
+
]
|
585 |
+
display_response(
|
586 |
+
response,
|
587 |
+
model_info,
|
588 |
+
gen_configs[post_key_prefix]["provider"],
|
589 |
+
gen_configs[post_key_prefix]["should_escape_markdown"],
|
590 |
+
gen_configs[post_key_prefix]["should_add_legis_urls"],
|
591 |
+
)
|
592 |
|
|
|
593 |
|
594 |
+
def main():
|
595 |
+
|
596 |
+
st.title(":classical_building: LegisQA :classical_building:")
|
597 |
+
st.header("Query Congressional Bills")
|
598 |
+
|
599 |
+
with st.sidebar:
|
600 |
+
render_sidebar()
|
601 |
+
|
602 |
+
query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
|
603 |
+
[
|
604 |
+
"RAG",
|
605 |
+
"RAG (side-by-side)",
|
606 |
+
"Guide",
|
607 |
+
]
|
608 |
+
)
|
609 |
+
|
610 |
+
with query_rag_tab:
|
611 |
+
render_query_rag_tab()
|
612 |
+
|
613 |
+
with query_rag_sbs_tab:
|
614 |
+
render_query_rag_sbs_tab()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
615 |
|
616 |
+
with guide_tab:
|
617 |
+
render_guide()
|
618 |
|
|
|
|
|
619 |
|
620 |
+
if __name__ == "__main__":
|
621 |
+
main()
|
usage.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
def get_openai_token_usage(metadata: dict, model_info: dict):
|
5 |
+
input_tokens = metadata["token_usage"]["prompt_tokens"]
|
6 |
+
output_tokens = metadata["token_usage"]["completion_tokens"]
|
7 |
+
cost = (
|
8 |
+
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
9 |
+
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
10 |
+
)
|
11 |
+
return {
|
12 |
+
"input_tokens": input_tokens,
|
13 |
+
"output_tokens": output_tokens,
|
14 |
+
"cost": cost,
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def get_anthropic_token_usage(metadata: dict, model_info: dict):
|
19 |
+
input_tokens = metadata["usage"]["input_tokens"]
|
20 |
+
output_tokens = metadata["usage"]["output_tokens"]
|
21 |
+
cost = (
|
22 |
+
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
23 |
+
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
24 |
+
)
|
25 |
+
return {
|
26 |
+
"input_tokens": input_tokens,
|
27 |
+
"output_tokens": output_tokens,
|
28 |
+
"cost": cost,
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
def get_together_token_usage(metadata: dict, model_info: dict):
|
33 |
+
input_tokens = metadata["token_usage"]["prompt_tokens"]
|
34 |
+
output_tokens = metadata["token_usage"]["completion_tokens"]
|
35 |
+
cost = (
|
36 |
+
input_tokens * 1e-6 * model_info["cost"]["pmi"]
|
37 |
+
+ output_tokens * 1e-6 * model_info["cost"]["pmo"]
|
38 |
+
)
|
39 |
+
return {
|
40 |
+
"input_tokens": input_tokens,
|
41 |
+
"output_tokens": output_tokens,
|
42 |
+
"cost": cost,
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
def get_token_usage(metadata: dict, model_info: dict, provider: str):
|
47 |
+
match provider:
|
48 |
+
case "OpenAI":
|
49 |
+
return get_openai_token_usage(metadata, model_info)
|
50 |
+
case "Anthropic":
|
51 |
+
return get_anthropic_token_usage(metadata, model_info)
|
52 |
+
case "Together":
|
53 |
+
return get_together_token_usage(metadata, model_info)
|
54 |
+
case _:
|
55 |
+
raise ValueError()
|
56 |
+
|
57 |
+
|
58 |
+
def display_api_usage(response, model_info, provider: str):
|
59 |
+
with st.container(border=True):
|
60 |
+
st.write("API Usage")
|
61 |
+
token_usage = get_token_usage(
|
62 |
+
response["aimessage"].response_metadata, model_info, provider
|
63 |
+
)
|
64 |
+
col1, col2, col3 = st.columns(3)
|
65 |
+
with col1:
|
66 |
+
st.metric("Input Tokens", token_usage["input_tokens"])
|
67 |
+
with col2:
|
68 |
+
st.metric("Output Tokens", token_usage["output_tokens"])
|
69 |
+
with col3:
|
70 |
+
st.metric("Cost", f"${token_usage['cost']:.4f}")
|
71 |
+
with st.expander("Response Metadata"):
|
72 |
+
st.warning(response["aimessage"].response_metadata)
|