gabrielaltay commited on
Commit
a01d550
1 Parent(s): cb6c0bd
Files changed (3) hide show
  1. app.py +12 -6
  2. custom_tools.py +0 -98
  3. retriever_tools.py +79 -0
app.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from collections import defaultdict
2
  import json
3
  import os
@@ -140,6 +145,7 @@ def format_docs(docs):
140
  dd = {
141
  "legis_id": doc_grp[0].metadata["legis_id"],
142
  "title": doc_grp[0].metadata["title"],
 
143
  "sponsor": doc_grp[0].metadata["sponsor_full_name"],
144
  "snippets": [doc.page_content for doc in doc_grp],
145
  }
@@ -308,7 +314,7 @@ def render_query_rag_tab():
308
 
309
  render_example_queries()
310
 
311
- QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "title", "legis_id", and "sponsor" in the response. If you don't know how to respond, just tell the user.
312
 
313
  ---
314
 
@@ -328,7 +334,7 @@ Query: {query}"""
328
  )
329
 
330
  with st.form("query_form"):
331
- st.text_area("Enter query:", key="query")
332
  query_submitted = st.form_submit_button("Submit")
333
 
334
  if query_submitted:
@@ -354,6 +360,7 @@ Query: {query}"""
354
  SS["out"] = rag_chain.invoke(SS["query"])
355
  SS["cb"] = cb
356
  else:
 
357
  SS["out"] = rag_chain.invoke(SS["query"])
358
 
359
  if "out" in SS:
@@ -386,7 +393,7 @@ Query: {query}"""
386
 
387
  def render_query_agent_tab():
388
 
389
- from custom_tools import get_retriever_tool
390
 
391
  from langchain_community.tools import WikipediaQueryRun
392
  from langchain_community.utilities import WikipediaAPIWrapper
@@ -465,9 +472,8 @@ def render_chat_agent_tab():
465
  ##################
466
 
467
 
468
- st.title(
469
- ":classical_building: LegisQA - Chat With Congressional Bills :classical_building:"
470
- )
471
 
472
 
473
  with st.sidebar:
 
1
+ """
2
+ TODO: checkout langgraph
3
+ TODO: clear screen between agent calls (see here https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/clear_results.py)
4
+ """
5
+
6
  from collections import defaultdict
7
  import json
8
  import os
 
145
  dd = {
146
  "legis_id": doc_grp[0].metadata["legis_id"],
147
  "title": doc_grp[0].metadata["title"],
148
+ "introduced_date": doc_grp[0].metadata["introduced_date"],
149
  "sponsor": doc_grp[0].metadata["sponsor_full_name"],
150
  "snippets": [doc.page_content for doc in doc_grp],
151
  }
 
314
 
315
  render_example_queries()
316
 
317
+ QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
318
 
319
  ---
320
 
 
334
  )
335
 
336
  with st.form("query_form"):
337
+ st.text_area("Enter a query that can be answered with congressional legislation:", key="query")
338
  query_submitted = st.form_submit_button("Submit")
339
 
340
  if query_submitted:
 
360
  SS["out"] = rag_chain.invoke(SS["query"])
361
  SS["cb"] = cb
362
  else:
363
+ SS.pop("cb", None)
364
  SS["out"] = rag_chain.invoke(SS["query"])
365
 
366
  if "out" in SS:
 
393
 
394
  def render_query_agent_tab():
395
 
396
+ from retriever_tools import get_retriever_tool
397
 
398
  from langchain_community.tools import WikipediaQueryRun
399
  from langchain_community.utilities import WikipediaAPIWrapper
 
472
  ##################
473
 
474
 
475
+ st.title(":classical_building: LegisQA :classical_building:")
476
+ st.header("Chat With Congressional Bills")
 
477
 
478
 
479
  with st.sidebar:
custom_tools.py DELETED
@@ -1,98 +0,0 @@
1
- """
2
- TODO clean all this up
3
- modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
4
- """
5
-
6
- from functools import partial
7
- from typing import Optional
8
-
9
- from langchain_core.callbacks.manager import Callbacks
10
- from langchain_core.prompts import BasePromptTemplate, PromptTemplate
11
- from langchain_core.pydantic_v1 import BaseModel, Field
12
- from langchain_core.retrievers import BaseRetriever
13
- from langchain.tools import Tool
14
-
15
-
16
- def get_retriever_tool(
17
- retriever,
18
- name,
19
- description,
20
- format_docs,
21
- *,
22
- document_prompt: Optional[BasePromptTemplate] = None,
23
- document_separator: str = "\n\n",
24
- ):
25
-
26
- class RetrieverInput(BaseModel):
27
- """Input to the retriever."""
28
-
29
- query: str = Field(description="query to look up in retriever")
30
-
31
-
32
- def _get_relevant_documents(
33
- query: str,
34
- retriever: BaseRetriever,
35
- document_prompt: BasePromptTemplate,
36
- document_separator: str,
37
- callbacks: Callbacks = None,
38
- ) -> str:
39
- docs = retriever.get_relevant_documents(query, callbacks=callbacks)
40
- return format_docs(docs)
41
-
42
- async def _aget_relevant_documents(
43
- query: str,
44
- retriever: BaseRetriever,
45
- document_prompt: BasePromptTemplate,
46
- document_separator: str,
47
- callbacks: Callbacks = None,
48
- ) -> str:
49
- docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
50
- return format_docs(docs)
51
-
52
- def create_retriever_tool(
53
- retriever: BaseRetriever,
54
- name: str,
55
- description: str,
56
- *,
57
- document_prompt: Optional[BasePromptTemplate] = None,
58
- document_separator: str = "\n\n",
59
- ) -> Tool:
60
- """Create a tool to do retrieval of documents.
61
-
62
- Args:
63
- retriever: The retriever to use for the retrieval
64
- name: The name for the tool. This will be passed to the language model,
65
- so should be unique and somewhat descriptive.
66
- description: The description for the tool. This will be passed to the language
67
- model, so should be descriptive.
68
-
69
- Returns:
70
- Tool class to pass to an agent
71
- """
72
- document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
73
- func = partial(
74
- _get_relevant_documents,
75
- retriever=retriever,
76
- document_prompt=document_prompt,
77
- document_separator=document_separator,
78
- )
79
- afunc = partial(
80
- _aget_relevant_documents,
81
- retriever=retriever,
82
- document_prompt=document_prompt,
83
- document_separator=document_separator,
84
- )
85
- return Tool(
86
- name=name,
87
- description=description,
88
- func=func,
89
- coroutine=afunc,
90
- args_schema=RetrieverInput,
91
- )
92
-
93
-
94
- return create_retriever_tool(
95
- retriever,
96
- name,
97
- description,
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
retriever_tools.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
3
+ """
4
+
5
+ from functools import partial
6
+ from typing import Callable
7
+ from typing import Iterable
8
+ from typing import Optional
9
+
10
+ from langchain.schema import Document
11
+ from langchain.tools import Tool
12
+ from langchain_core.callbacks.manager import Callbacks
13
+ from langchain_core.pydantic_v1 import BaseModel
14
+ from langchain_core.pydantic_v1 import Field
15
+ from langchain_core.retrievers import BaseRetriever
16
+
17
+
18
+ class RetrieverInput(BaseModel):
19
+ """Input to the retriever."""
20
+ query: str = Field(description="query to look up in retriever")
21
+
22
+
23
+ def _get_relevant_documents(
24
+ query: str,
25
+ retriever: BaseRetriever,
26
+ format_docs: Callable[[Iterable[Document]], str],
27
+ callbacks: Callbacks = None,
28
+ ) -> str:
29
+ docs = retriever.get_relevant_documents(query, callbacks=callbacks)
30
+ return format_docs(docs)
31
+
32
+
33
+ async def _aget_relevant_documents(
34
+ query: str,
35
+ retriever: BaseRetriever,
36
+ format_docs: Callable[[Iterable[Document]], str],
37
+ callbacks: Callbacks = None,
38
+ ) -> str:
39
+ docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
40
+ return format_docs(docs)
41
+
42
+
43
+ def get_retriever_tool(
44
+ retriever: BaseRetriever,
45
+ name: str,
46
+ description: str,
47
+ format_docs: Callable[[Iterable[Document]], str],
48
+ ) -> Tool:
49
+
50
+ """Create a tool to do retrieval of documents.
51
+
52
+ Args:
53
+ retriever: The retriever to use for the retrieval
54
+ name: The name for the tool. This will be passed to the language model,
55
+ so should be unique and somewhat descriptive.
56
+ description: The description for the tool. This will be passed to the language
57
+ model, so should be descriptive.
58
+ format_docs: A function to turn an iterable of docs into a string.
59
+
60
+ Returns:
61
+ Tool class to pass to an agent
62
+ """
63
+ func = partial(
64
+ _get_relevant_documents,
65
+ retriever=retriever,
66
+ format_docs=format_docs,
67
+ )
68
+ afunc = partial(
69
+ _aget_relevant_documents,
70
+ retriever=retriever,
71
+ format_docs=format_docs,
72
+ )
73
+ return Tool(
74
+ name=name,
75
+ description=description,
76
+ func=func,
77
+ coroutine=afunc,
78
+ args_schema=RetrieverInput,
79
+ )