Commit 76cbdff · Parent(s): 793d0f2

more files

Files changed:
- app.py +27 -231
- doc_format_mod.py +102 -0
- guide_mod.py +22 -0
- retriever_tools.py +0 -79
- sidebar_mod.py +20 -0
- usage.py → usage_mod.py +18 -18
- utils_mod.py +47 -0
- vectorstore_mod.py +46 -0
app.py
CHANGED

@@ -10,15 +10,17 @@ from langchain_core.documents import Document
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel
 from langchain_core.runnables import RunnablePassthrough
-from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-from langchain_community.vectorstores.utils import DistanceStrategy
 from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_together import ChatTogether
-from langchain_pinecone import PineconeVectorStore
 import streamlit as st
 
-import …
+import utils_mod
+import doc_format_mod
+import guide_mod
+import sidebar_mod
+import usage_mod
+import vectorstore_mod
 
 
 st.set_page_config(layout="wide", page_title="LegisQA")

@@ -32,16 +34,7 @@ SS = st.session_state
 SEED = 292764
 CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
 SPONSOR_PARTIES = ["D", "R", "L", "I"]
-CONGRESS_GOV_TYPE_MAP = {
-    "hconres": "house-concurrent-resolution",
-    "hjres": "house-joint-resolution",
-    "hr": "house-bill",
-    "hres": "house-resolution",
-    "s": "senate-bill",
-    "sconres": "senate-concurrent-resolution",
-    "sjres": "senate-joint-resolution",
-    "sres": "senate-resolution",
-}
+
 OPENAI_CHAT_MODELS = {
     "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
     "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},

@@ -68,190 +61,6 @@ PROVIDER_MODELS = {
 }
 
 
-def get_sponsor_url(bioguide_id: str) -> str:
-    return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
-
-
-def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
-    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
-    return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
-
-
-def load_bge_embeddings():
-    model_name = "BAAI/bge-small-en-v1.5"
-    model_kwargs = {"device": "cpu"}
-    encode_kwargs = {"normalize_embeddings": True}
-    emb_fn = HuggingFaceBgeEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs,
-        query_instruction="Represent this question for searching relevant passages: ",
-    )
-    return emb_fn
-
-
-def load_pinecone_vectorstore():
-    emb_fn = load_bge_embeddings()
-    vectorstore = PineconeVectorStore(
-        embedding=emb_fn,
-        text_key="text",
-        distance_strategy=DistanceStrategy.COSINE,
-        pinecone_api_key=st.secrets["pinecone_api_key"],
-        index_name=st.secrets["pinecone_index_name"],
-    )
-    return vectorstore
-
-
-def render_outreach_links():
-    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
-    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"
-    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
-    hf_url = "https://huggingface.co/hyperdemocracy"
-    pc_url = "https://www.pinecone.io/blog/serverless"
-    together_url = "https://www.together.ai/"
-    st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)")
-    st.subheader(f":world_map: Visualize [nomic atlas]({nomic_url})")
-    st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
-    st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
-    st.subheader(f":pancakes: Inference [together.ai]({together_url})")
-
-
-def render_sidebar():
-
-    with st.container(border=True):
-        render_outreach_links()
-
-
-def group_docs(docs) -> list[tuple[str, list[Document]]]:
-    doc_grps = defaultdict(list)
-
-    # create legis_id groups
-    for doc in docs:
-        doc_grps[doc.metadata["legis_id"]].append(doc)
-
-    # sort docs in each group by start index
-    for legis_id in doc_grps.keys():
-        doc_grps[legis_id] = sorted(
-            doc_grps[legis_id],
-            key=lambda x: x.metadata["start_index"],
-        )
-
-    # sort groups by number of docs
-    doc_grps = sorted(
-        tuple(doc_grps.items()),
-        key=lambda x: -len(x[1]),
-    )
-
-    return doc_grps
-
-
-def format_docs(docs: list[Document]) -> str:
-    """JSON grouped"""
-
-    doc_grps = group_docs(docs)
-    out = []
-    for legis_id, doc_grp in doc_grps:
-        dd = {
-            "legis_id": doc_grp[0].metadata["legis_id"],
-            "title": doc_grp[0].metadata["title"],
-            "introduced_date": doc_grp[0].metadata["introduced_date"],
-            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
-            "snippets": [doc.page_content for doc in doc_grp],
-        }
-        out.append(dd)
-    return json.dumps(out, indent=4)
-
-
-def escape_markdown(text: str) -> str:
-    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
-    for char in MD_SPECIAL_CHARS:
-        text = text.replace(char, "\\" + char)
-    return text
-
-
-def get_vectorstore_filter(ret_config: dict) -> dict:
-    vs_filter = {}
-    if ret_config["filter_legis_id"] != "":
-        vs_filter["legis_id"] = ret_config["filter_legis_id"]
-    if ret_config["filter_bioguide_id"] != "":
-        vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
-    vs_filter = {
-        **vs_filter,
-        "congress_num": {"$in": ret_config["filter_congress_nums"]},
-    }
-    vs_filter = {
-        **vs_filter,
-        "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
-    }
-    return vs_filter
-
-
-def render_doc_grp(legis_id: str, doc_grp: list[Document]):
-    first_doc = doc_grp[0]
-
-    congress_gov_url = get_congress_gov_url(
-        first_doc.metadata["congress_num"],
-        first_doc.metadata["legis_type"],
-        first_doc.metadata["legis_num"],
-    )
-    congress_gov_link = f"[congress.gov]({congress_gov_url})"
-
-    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
-        len(doc_grp),
-        first_doc.metadata["legis_id"],
-        first_doc.metadata["title"],
-        congress_gov_link,
-        first_doc.metadata["sponsor_full_name"],
-        first_doc.metadata["sponsor_bioguide_id"],
-        get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
-    )
-    doc_contents = [
-        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
-        for doc in doc_grp
-    ]
-    with st.expander(ref):
-        st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
-
-
-def legis_id_to_link(legis_id: str) -> str:
-    congress_num, legis_type, legis_num = legis_id.split("-")
-    return get_congress_gov_url(congress_num, legis_type, legis_num)
-
-
-def legis_id_match_to_link(matchobj):
-    mstring = matchobj.string[matchobj.start() : matchobj.end()]
-    url = legis_id_to_link(mstring)
-    link = f"[{mstring}]({url})"
-    return link
-
-
-def replace_legis_ids_with_urls(text):
-    pattern = "11[345678]-[a-z]+-\d{1,5}"
-    rtext = re.sub(pattern, legis_id_match_to_link, text)
-    return rtext
-
-
-def render_guide():
-
-    st.write(
-        """
-When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
-
-
-## Disclaimer
-
-This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
-
-
-## Config
-
-Use the `Generative Config` to change LLM parameters.
-Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
-
-"""
-    )
-
-
 def render_example_queries():
 
     with st.expander("Example Queries"):

@@ -413,7 +222,7 @@ def get_llm(gen_config: dict):
 
 
 def create_rag_chain(llm, retriever):
-    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
+    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.
 
 ---
 

@@ -438,7 +247,7 @@ Query: {query}"""
                 "query": RunnablePassthrough(),
             }
         )
-        .assign(context=lambda x: format_docs(x["docs"]))
+        .assign(context=lambda x: doc_format_mod.format_docs(x["docs"]))
         .assign(aimessage=prompt | llm)
     )
 

@@ -446,9 +255,9 @@ Query: {query}"""
 
 
 def process_query(gen_config: dict, ret_config: dict, query: str):
-    vectorstore = load_pinecone_vectorstore()
+    vectorstore = vectorstore_mod.load_pinecone_vectorstore()
     llm = get_llm(gen_config)
-    vs_filter = get_vectorstore_filter(ret_config)
+    vs_filter = vectorstore_mod.get_vectorstore_filter(ret_config)
     retriever = vectorstore.as_retriever(
         search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
     )

@@ -457,44 +266,31 @@ def process_query(gen_config: dict, ret_config: dict, query: str):
     return response
 
 
-def render_retrieved_chunks(docs, tag: str | None = None):
-
-    doc_grps = group_docs(docs)
-    if tag is None:
-        st.write(
-            "Retrieved Chunks\n\nleft click to expand, right click to follow links"
-        )
-    else:
-        st.write(
-            f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
-        )
-    for legis_id, doc_grp in doc_grps:
-        render_doc_grp(legis_id, doc_grp)
-
-
-def display_response(
-    response,
+def render_response(
+    response: dict,
     model_info: dict,
     provider: str,
     should_escape_markdown: bool,
    should_add_legis_urls: bool,
-    tag: str|None=None
+    tag: str | None = None,
 ):
-
+    response_text = response["aimessage"].content
     if should_escape_markdown:
-        …
+        response_text = utils_mod.escape_markdown(response_text)
     if should_add_legis_urls:
-        …
+        response_text = utils_mod.replace_legis_ids_with_urls(response_text)
 
     with st.container(border=True):
         if tag is None:
             st.write("Response")
         else:
             st.write(f"Response ({tag})")
-        st.info(…)
+        st.info(response_text)
 
-    display_api_usage(response, model_info, provider, tag=tag)
-    render_retrieved_chunks(response["docs"], tag=tag)
+    usage_mod.display_api_usage(
+        response["aimessage"].response_metadata, model_info, provider, tag=tag
+    )
+    doc_format_mod.render_retrieved_chunks(response["docs"], tag=tag)
 
 
 def render_query_rag_tab():

@@ -527,7 +323,7 @@ def render_query_rag_tab():
 
     if response := SS.get(rkey):
         model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]]
-        display_response(
+        render_response(
             response,
             model_info,
             gen_config["provider"],

@@ -595,13 +391,13 @@ def render_query_rag_sbs_tab():
             model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
                 gen_configs[post_key_prefix]["model_name"]
             ]
-            display_response(
+            render_response(
                 response,
                 model_info,
                 gen_configs[post_key_prefix]["provider"],
                 gen_configs[post_key_prefix]["should_escape_markdown"],
                 gen_configs[post_key_prefix]["should_add_legis_urls"],
-                tag
+                tag=grp_names[post_key_prefix],
             )

@@ -611,7 +407,7 @@ def main():
     st.header("Query Congressional Bills")
 
     with st.sidebar:
-        render_sidebar()
+        sidebar_mod.render_sidebar()
 
     query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
         [

@@ -628,7 +424,7 @@ def main():
         render_query_rag_sbs_tab()
 
     with guide_tab:
-        render_guide()
+        guide_mod.render_guide()
 
 
 if __name__ == "__main__":
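Note: the hunks above only show the changed lines of create_rag_chain. For orientation, a minimal sketch of the full LCEL pipeline they imply; the `rag_chain` name and the assumption that `prompt` wraps QUERY_RAG_TEMPLATE are mine, not shown in the diff:

    # inside create_rag_chain(llm, retriever), reconstructed from the context lines above
    rag_chain = (
        RunnableParallel(
            {
                "docs": retriever,               # retrieved Documents
                "query": RunnablePassthrough(),  # the raw user query
            }
        )
        # serialize the retrieved docs into the JSON list QUERY_RAG_TEMPLATE describes
        .assign(context=lambda x: doc_format_mod.format_docs(x["docs"]))
        # fill the prompt with {context, query} and call the model
        .assign(aimessage=prompt | llm)
    )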
doc_format_mod.py
ADDED

@@ -0,0 +1,102 @@
+from collections import defaultdict
+import json
+
+from langchain.schema import Document
+import streamlit as st
+
+import utils_mod
+
+
+def group_docs(docs) -> list[tuple[str, list[Document]]]:
+    """Group and sort docs.
+
+    docs are grouped by legis_id
+    inside a legis_id group, the docs are sorted by start_index
+    overall the legis_id groups are sorted by number of docs (desc)
+
+    doc_grps = [
+        (legis_id, start_index sorted docs),  # group with the most docs
+        (legis_id, start_index sorted docs),
+        ...
+        (legis_id, start_index sorted docs),  # group with the least docs
+    ]
+    """
+    doc_grps = defaultdict(list)
+
+    # create legis_id groups
+    for doc in docs:
+        doc_grps[doc.metadata["legis_id"]].append(doc)
+
+    # sort docs in each group by start index
+    for legis_id in doc_grps.keys():
+        doc_grps[legis_id] = sorted(
+            doc_grps[legis_id],
+            key=lambda x: x.metadata["start_index"],
+        )
+
+    # sort groups by number of docs
+    doc_grps = sorted(
+        tuple(doc_grps.items()),
+        key=lambda x: -len(x[1]),
+    )
+
+    return doc_grps
+
+
+def format_docs(docs: list[Document]) -> str:
+    """JSON grouped"""
+
+    doc_grps = group_docs(docs)
+    out = []
+    for legis_id, doc_grp in doc_grps:
+        dd = {
+            "legis_id": doc_grp[0].metadata["legis_id"],
+            "title": doc_grp[0].metadata["title"],
+            "introduced_date": doc_grp[0].metadata["introduced_date"],
+            "sponsor": doc_grp[0].metadata["sponsor_full_name"],
+            "snippets": [doc.page_content for doc in doc_grp],
+        }
+        out.append(dd)
+    return json.dumps(out, indent=4)
+
+
+def render_doc_grp(legis_id: str, doc_grp: list[Document]):
+    first_doc = doc_grp[0]
+
+    congress_gov_url = utils_mod.get_congress_gov_url(
+        first_doc.metadata["congress_num"],
+        first_doc.metadata["legis_type"],
+        first_doc.metadata["legis_num"],
+    )
+    congress_gov_link = f"[congress.gov]({congress_gov_url})"
+
+    ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
+        len(doc_grp),
+        first_doc.metadata["legis_id"],
+        first_doc.metadata["title"],
+        congress_gov_link,
+        first_doc.metadata["sponsor_full_name"],
+        first_doc.metadata["sponsor_bioguide_id"],
+        utils_mod.get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
+    )
+    doc_contents = [
+        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
+        for doc in doc_grp
+    ]
+    with st.expander(ref):
+        st.write(utils_mod.escape_markdown("\n\n...\n\n".join(doc_contents)))
+
+
+def render_retrieved_chunks(docs: list[Document], tag: str | None = None):
+    with st.container(border=True):
+        doc_grps = group_docs(docs)
+        if tag is None:
+            st.write(
+                "Retrieved Chunks\n\nleft click to expand, right click to follow links"
+            )
+        else:
+            st.write(
+                f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
+            )
+        for legis_id, doc_grp in doc_grps:
+            render_doc_grp(legis_id, doc_grp)
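Note: a quick way to see what group_docs/format_docs produce. The metadata values below are hypothetical; the behavior (grouping by legis_id, snippets reordered by start_index, largest group first) follows directly from the code above:

    from langchain.schema import Document
    import doc_format_mod

    meta = {"legis_id": "118-hr-1234", "title": "Example Act",
            "introduced_date": "2023-05-01", "sponsor_full_name": "Jane Doe"}
    docs = [
        Document(page_content="Sec. 2 ...", metadata={**meta, "start_index": 512}),
        Document(page_content="Sec. 1 ...", metadata={**meta, "start_index": 0}),
    ]
    # prints one JSON object for the bill, snippets reordered to ["Sec. 1 ...", "Sec. 2 ..."]
    print(doc_format_mod.format_docs(docs))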
guide_mod.py
ADDED

@@ -0,0 +1,22 @@
+import streamlit as st
+
+
+def render_guide():
+
+    st.write(
+        """
+When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
+
+
+## Disclaimer
+
+This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
+
+
+## Config
+
+Use the `Generative Config` to change LLM parameters.
+Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
+
+"""
+    )
retriever_tools.py
DELETED

@@ -1,79 +0,0 @@
-"""
-modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
-"""
-
-from functools import partial
-from typing import Callable
-from typing import Iterable
-from typing import Optional
-
-from langchain.schema import Document
-from langchain.tools import Tool
-from langchain_core.callbacks.manager import Callbacks
-from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.pydantic_v1 import Field
-from langchain_core.retrievers import BaseRetriever
-
-
-class RetrieverInput(BaseModel):
-    """Input to the retriever."""
-    query: str = Field(description="query to look up in retriever")
-
-
-def _get_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    format_docs: Callable[[Iterable[Document]], str],
-    callbacks: Callbacks = None,
-) -> str:
-    docs = retriever.get_relevant_documents(query, callbacks=callbacks)
-    return format_docs(docs)
-
-
-async def _aget_relevant_documents(
-    query: str,
-    retriever: BaseRetriever,
-    format_docs: Callable[[Iterable[Document]], str],
-    callbacks: Callbacks = None,
-) -> str:
-    docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
-    return format_docs(docs)
-
-
-def get_retriever_tool(
-    retriever: BaseRetriever,
-    name: str,
-    description: str,
-    format_docs: Callable[[Iterable[Document]], str],
-) -> Tool:
-
-    """Create a tool to do retrieval of documents.
-
-    Args:
-        retriever: The retriever to use for the retrieval
-        name: The name for the tool. This will be passed to the language model,
-            so should be unique and somewhat descriptive.
-        description: The description for the tool. This will be passed to the language
-            model, so should be descriptive.
-        format_docs: A function to turn an iterable of docs into a string.
-
-    Returns:
-        Tool class to pass to an agent
-    """
-    func = partial(
-        _get_relevant_documents,
-        retriever=retriever,
-        format_docs=format_docs,
-    )
-    afunc = partial(
-        _aget_relevant_documents,
-        retriever=retriever,
-        format_docs=format_docs,
-    )
-    return Tool(
-        name=name,
-        description=description,
-        func=func,
-        coroutine=afunc,
-        args_schema=RetrieverInput,
-    )
sidebar_mod.py
ADDED

@@ -0,0 +1,20 @@
+import streamlit as st
+
+
+def render_outreach_links():
+    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
+    nomic_map_name = "us-congressional-legislation-s1024o256nomic-1"
+    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
+    hf_url = "https://huggingface.co/hyperdemocracy"
+    pc_url = "https://www.pinecone.io/blog/serverless"
+    together_url = "https://www.together.ai/"
+    st.subheader(":brain: About [hyperdemocracy](https://hyperdemocracy.us)")
+    st.subheader(f":world_map: Visualize [nomic atlas]({nomic_url})")
+    st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
+    st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
+    st.subheader(f":pancakes: Inference [together.ai]({together_url})")
+
+
+def render_sidebar():
+    with st.container(border=True):
+        render_outreach_links()
usage.py → usage_mod.py
RENAMED

@@ -1,9 +1,9 @@
 import streamlit as st
 
 
-def get_openai_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
+def get_openai_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["token_usage"]["prompt_tokens"]
+    output_tokens = response_metadata["token_usage"]["completion_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]

@@ -15,9 +15,9 @@ def get_openai_token_usage(metadata: dict, model_info: dict):
     }
 
 
-def get_anthropic_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["usage"]["input_tokens"]
-    output_tokens = metadata["usage"]["output_tokens"]
+def get_anthropic_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["usage"]["input_tokens"]
+    output_tokens = response_metadata["usage"]["output_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]

@@ -29,9 +29,9 @@ def get_anthropic_token_usage(metadata: dict, model_info: dict):
     }
 
 
-def get_together_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
+def get_together_token_usage(response_metadata: dict, model_info: dict):
+    input_tokens = response_metadata["token_usage"]["prompt_tokens"]
+    output_tokens = response_metadata["token_usage"]["completion_tokens"]
     cost = (
         input_tokens * 1e-6 * model_info["cost"]["pmi"]
         + output_tokens * 1e-6 * model_info["cost"]["pmo"]

@@ -43,27 +43,27 @@ def get_together_token_usage(metadata: dict, model_info: dict):
     }
 
 
-def get_token_usage(metadata: dict, model_info: dict, provider: str):
+def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
     match provider:
         case "OpenAI":
-            return get_openai_token_usage(metadata, model_info)
+            return get_openai_token_usage(response_metadata, model_info)
         case "Anthropic":
-            return get_anthropic_token_usage(metadata, model_info)
+            return get_anthropic_token_usage(response_metadata, model_info)
         case "Together":
-            return get_together_token_usage(metadata, model_info)
+            return get_together_token_usage(response_metadata, model_info)
         case _:
             raise ValueError()
 
 
-def display_api_usage(response, model_info, provider: str, tag: str|None=None):
+def display_api_usage(
+    response_metadata: dict, model_info: dict, provider: str, tag: str | None = None
+):
     with st.container(border=True):
         if tag is None:
             st.write("API Usage")
         else:
             st.write(f"API Usage ({tag})")
-        token_usage = get_token_usage(
-            response["aimessage"].response_metadata, model_info, provider
-        )
+        token_usage = get_token_usage(response_metadata, model_info, provider)
         col1, col2, col3 = st.columns(3)
         with col1:
             st.metric("Input Tokens", token_usage["input_tokens"])

@@ -72,4 +72,4 @@ def display_api_usage(response, model_info, provider: str, tag: str|None=None):
         with col3:
             st.metric("Cost", f"${token_usage['cost']:.4f}")
         with st.expander("Response Metadata"):
-            st.warning(response["aimessage"].response_metadata)
+            st.warning(response_metadata)
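Note: all three providers share the same per-million-token cost formula; only the metadata keys differ. A worked example with the gpt-4o-mini rates defined in app.py (the token counts are made up):

    pmi, pmo = 0.15, 0.60                      # $ per million input/output tokens (gpt-4o-mini)
    input_tokens, output_tokens = 2000, 1000   # hypothetical usage
    cost = input_tokens * 1e-6 * pmi + output_tokens * 1e-6 * pmo
    print(f"${cost:.4f}")                      # $0.0009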
utils_mod.py
ADDED

@@ -0,0 +1,47 @@
+import re
+
+
+CONGRESS_GOV_TYPE_MAP = {
+    "hconres": "house-concurrent-resolution",
+    "hjres": "house-joint-resolution",
+    "hr": "house-bill",
+    "hres": "house-resolution",
+    "s": "senate-bill",
+    "sconres": "senate-concurrent-resolution",
+    "sjres": "senate-joint-resolution",
+    "sres": "senate-resolution",
+}
+
+
+def escape_markdown(text: str) -> str:
+    MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
+    for char in MD_SPECIAL_CHARS:
+        text = text.replace(char, "\\" + char)
+    return text
+
+
+def get_sponsor_url(bioguide_id: str) -> str:
+    return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
+
+
+def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
+    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
+    return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
+
+
+def legis_id_to_link(legis_id: str) -> str:
+    congress_num, legis_type, legis_num = legis_id.split("-")
+    return get_congress_gov_url(congress_num, legis_type, legis_num)
+
+
+def legis_id_match_to_link(matchobj):
+    mstring = matchobj.string[matchobj.start() : matchobj.end()]
+    url = legis_id_to_link(mstring)
+    link = f"[{mstring}]({url})"
+    return link
+
+
+def replace_legis_ids_with_urls(text: str) -> str:
+    pattern = "11[345678]-[a-z]+-\d{1,5}"
+    rtext = re.sub(pattern, legis_id_match_to_link, text)
+    return rtext
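Note: replace_legis_ids_with_urls rewrites any bare legis_id from the 113th-118th congresses into a congress.gov markdown link via CONGRESS_GOV_TYPE_MAP. For example (the bill number here is illustrative):

    import utils_mod

    text = "See 118-hr-1234 for the House version."
    print(utils_mod.replace_legis_ids_with_urls(text))
    # See [118-hr-1234](https://www.congress.gov/bill/118th-congress/house-bill/1234) for the House version.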
vectorstore_mod.py
ADDED

@@ -0,0 +1,46 @@
+import streamlit as st
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_pinecone import PineconeVectorStore
+from langchain_community.vectorstores.utils import DistanceStrategy
+
+
+def load_bge_embeddings():
+    model_name = "BAAI/bge-small-en-v1.5"
+    model_kwargs = {"device": "cpu"}
+    encode_kwargs = {"normalize_embeddings": True}
+    emb_fn = HuggingFaceBgeEmbeddings(
+        model_name=model_name,
+        model_kwargs=model_kwargs,
+        encode_kwargs=encode_kwargs,
+        query_instruction="Represent this question for searching relevant passages: ",
+    )
+    return emb_fn
+
+
+def load_pinecone_vectorstore():
+    emb_fn = load_bge_embeddings()
+    vectorstore = PineconeVectorStore(
+        embedding=emb_fn,
+        text_key="text",
+        distance_strategy=DistanceStrategy.COSINE,
+        pinecone_api_key=st.secrets["pinecone_api_key"],
+        index_name=st.secrets["pinecone_index_name"],
+    )
+    return vectorstore
+
+
+def get_vectorstore_filter(ret_config: dict) -> dict:
+    vs_filter = {}
+    if ret_config["filter_legis_id"] != "":
+        vs_filter["legis_id"] = ret_config["filter_legis_id"]
+    if ret_config["filter_bioguide_id"] != "":
+        vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
+    vs_filter = {
+        **vs_filter,
+        "congress_num": {"$in": ret_config["filter_congress_nums"]},
+    }
+    vs_filter = {
+        **vs_filter,
+        "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
+    }
+    return vs_filter
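Note: get_vectorstore_filter always applies the congress_num and sponsor_party $in clauses, and only adds the exact-match keys when the corresponding text filters are non-empty. A sketch with a hypothetical ret_config:

    import vectorstore_mod

    ret_config = {
        "filter_legis_id": "",       # empty, so no legis_id key in the filter
        "filter_bioguide_id": "",    # empty, so no sponsor_bioguide_id key
        "filter_congress_nums": [117, 118],
        "filter_sponsor_parties": ["D", "R"],
    }
    print(vectorstore_mod.get_vectorstore_filter(ret_config))
    # {'congress_num': {'$in': [117, 118]}, 'sponsor_party': {'$in': ['D', 'R']}}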