Asaad Almutareb committed
Commit d713a77
1 Parent(s): 57e87b0

reordered the tools and adjusted their descriptions

added the new tools to the active toolkit
added a FastAPI wrapper for Gradio

app.py CHANGED
@@ -1,3 +1,4 @@
+from fastapi import FastAPI
 import gradio as gr
 from hf_mixtral_agent import agent_executor
 from innovation_pathfinder_ai.source_container.container import (
@@ -19,6 +20,8 @@ dotenv.load_dotenv()
 
 logger = logger.get_console_logger("app")
 
+app = FastAPI()
+
 def initialize_chroma_db() -> Chroma:
     collection_name=os.getenv("CONVERSATION_COLLECTION_NAME")
 
@@ -102,8 +105,7 @@ if __name__ == "__main__":
        )
        clear.click(lambda: None, None, chatbot, queue=False)
 
-    demo.queue()
-    demo.launch(debug=True, share=True)
-
+    demo.queue().launch(debug=True, share=True)
 
-    x = 0 # for debugging purposes
+    x = 0 # for debugging purposes
+    app = gr.mount_gradio_app(app, demo, path="/")
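A hedged usage sketch (not part of this commit): with the Gradio UI mounted on a FastAPI app, the combined application would typically be served by an ASGI server such as uvicorn. Note that uvicorn is not added to requirements.txt here, and that in app.py the gr.mount_gradio_app(...) call sits inside the `if __name__ == "__main__":` block, so an external `uvicorn app:app` would import the FastAPI instance before the Gradio routes are attached.

# Hypothetical sketch, assuming uvicorn is installed; the placeholder Blocks stands in
# for the chatbot UI that app.py builds under `if __name__ == "__main__":`.
from fastapi import FastAPI
import gradio as gr
import uvicorn

app = FastAPI()

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

# Mount the Gradio UI on the FastAPI app at the root path, as app.py now does.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Serve the combined app on the usual Gradio port.
    uvicorn.run(app, host="0.0.0.0", port=7860)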
hf_mixtral_agent.py CHANGED
@@ -8,7 +8,7 @@ from langchain.tools.render import render_text_description
 import os
 from dotenv import load_dotenv
 from innovation_pathfinder_ai.structured_tools.structured_tools import (
-    arxiv_search, get_arxiv_paper, google_search, wikipedia_search
+    arxiv_search, get_arxiv_paper, google_search, wikipedia_search, knowledgeBase_search, memory_search
 )
 
 from langchain.prompts import PromptTemplate
@@ -36,6 +36,8 @@ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
 
 
 tools = [
+    memory_search,
+    knowledgeBase_search,
     arxiv_search,
     wikipedia_search,
     google_search,
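The reordering matters because the agent prompt advertises the tools in list order: render_text_description, which hf_mixtral_agent.py imports at the top, emits one "name: description" line per tool, so memory_search and knowledgeBase_search are now presented to the model before the online tools. A self-contained illustration with hypothetical stub tools (not the repo's actual tools):

# Minimal sketch of how tool order flows into the rendered tool description.
from langchain.tools import tool
from langchain.tools.render import render_text_description

@tool
def memory_search_stub(query: str) -> str:
    """Check previously stored knowledge before going online."""
    return ""

@tool
def arxiv_search_stub(query: str) -> str:
    """Search arxiv for papers when local knowledge is not enough."""
    return ""

# Listing the memory tool first puts it first in the prompt's tool section.
print(render_text_description([memory_search_stub, arxiv_search_stub]))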
innovation_pathfinder_ai/structured_tools/structured_tools.py CHANGED
@@ -35,65 +35,35 @@ import os
 # from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
 
 @tool
-def arxiv_search(query: str) -> str:
-    """Search arxiv database for scientific research papers and studies. This is your primary information source.
-    always check it first when you search for information, before using any other tool."""
-    global all_sources
-    arxiv_retriever = ArxivRetriever(load_max_docs=3)
-    data = arxiv_retriever.invoke(query)
-    meta_data = [i.metadata for i in data]
-    formatted_sources = format_arxiv_documents(data)
-    all_sources += formatted_sources
-    parsed_sources = parse_list_to_dicts(formatted_sources)
-    add_many(parsed_sources)
-
-    return data.__str__()
-
-@tool
-def get_arxiv_paper(paper_id:str) -> None:
-    """Download a paper from axriv to download a paper please input
-    the axriv id such as "1605.08386v1" This tool is named get_arxiv_paper
-    If you input "http://arxiv.org/abs/2312.02813", This will break the code. Also only do
-    "2312.02813". In addition please download one paper at a time. Pleaase keep the inputs/output
-    free of additional information only have the id.
-    """
-    # code from https://lukasschwab.me/arxiv.py/arxiv.html
-    paper = next(arxiv.Client().results(arxiv.Search(id_list=[paper_id])))
-
-    number_without_period = paper_id.replace('.', '')
-
-    # Download the PDF to a specified directory with a custom filename.
-    paper.download_pdf(dirpath="./downloaded_papers", filename=f"{number_without_period}.pdf")
-
-
-@tool
-def google_search(query: str) -> str:
-    """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
-    global all_sources
-
-    websearch = GoogleSearchAPIWrapper()
-    search_results:dict = websearch.results(query, 3)
-    cleaner_sources =format_search_results(search_results)
-    parsed_csources = parse_list_to_dicts(cleaner_sources)
-    add_many(parsed_csources)
-    all_sources += cleaner_sources
-
-    return cleaner_sources.__str__()
-
-@tool
-def wikipedia_search(query: str) -> str:
-    """Search Wikipedia for additional information to expand on research papers or when no papers can be found."""
-    global all_sources
-
-    api_wrapper = WikipediaAPIWrapper()
-    wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
-    wikipedia_results = wikipedia_search.run(query)
-    all_sources += create_wikipedia_urls_from_text(wikipedia_results)
-    return wikipedia_results
+def memory_search(query:str) -> str:
+    """Search the memory vector store for existing knowledge and relevant previous research. \
+    This is your primary source: start your search by checking what you have already learned in the past, before going online."""
+    # Since we have more than one collections we should change the name of this tool
+    client = chromadb.PersistentClient(
+        # path=persist_directory,
+    )
+
+    collection_name=os.getenv("CONVERSATION_COLLECTION_NAME")
+    #store using envar
+
+    embedding_function = SentenceTransformerEmbeddings(
+        model_name="all-MiniLM-L6-v2",
+    )
+
+    vector_db = Chroma(
+        client=client, # client for Chroma
+        collection_name=collection_name,
+        embedding_function=embedding_function,
+    )
+
+    retriever = vector_db.as_retriever()
+    docs = retriever.get_relevant_documents(query)
+
+    return docs.__str__()
 
 @tool
-def chroma_search(query:str) -> str:
-    """Search the Arxiv vector store for docmunets and relevent chunks"""
+def knowledgeBase_search(query:str) -> str:
+    """Search the internal knowledge base for research papers and relevant chunks"""
     # Since we have more than one collections we should change the name of this tool
     client = chromadb.PersistentClient(
         # path=persist_directory,
@@ -117,6 +87,36 @@ def chroma_search(query:str) -> str:
 
     return docs.__str__()
 
+@tool
+def arxiv_search(query: str) -> str:
+    """Search arxiv database for scientific research papers and studies. This is your primary online information source.
+    Always check it first when you search for additional information, before using any other online tool."""
+    global all_sources
+    arxiv_retriever = ArxivRetriever(load_max_docs=3)
+    data = arxiv_retriever.invoke(query)
+    meta_data = [i.metadata for i in data]
+    formatted_sources = format_arxiv_documents(data)
+    all_sources += formatted_sources
+    parsed_sources = parse_list_to_dicts(formatted_sources)
+    add_many(parsed_sources)
+
+    return data.__str__()
+
+@tool
+def get_arxiv_paper(paper_id:str) -> None:
+    """Download a paper from arxiv. To download a paper, please input only
+    the arxiv id, such as "1605.08386v1". This tool is named get_arxiv_paper.
+    If you input a full URL such as "http://arxiv.org/abs/2312.02813", this will break the code; only pass
+    "2312.02813". In addition, please download one paper at a time and keep the inputs/output
+    free of additional information: only the id.
+    """
+    # code from https://lukasschwab.me/arxiv.py/arxiv.html
+    paper = next(arxiv.Client().results(arxiv.Search(id_list=[paper_id])))
+
+    number_without_period = paper_id.replace('.', '')
+
+    # Download the PDF to a specified directory with a custom filename.
+    paper.download_pdf(dirpath="./downloaded_papers", filename=f"{number_without_period}.pdf")
 
 @tool
 def embed_arvix_paper(paper_id:str) -> None:
@@ -158,27 +158,26 @@ def embed_arvix_paper(paper_id:str) -> None:
     )
 
 @tool
-def conversational_search(query:str) -> str:
-    """Search from past conversations for docmunets and relevent chunks"""
-    # Since we have more than one collections we should change the name of this tool
-    client = chromadb.PersistentClient(
-        # path=persist_directory,
-    )
-
-    collection_name=os.getenv("CONVERSATION_COLLECTION_NAME")
-    #store using envar
-
-    embedding_function = SentenceTransformerEmbeddings(
-        model_name="all-MiniLM-L6-v2",
-    )
-
-    vector_db = Chroma(
-        client=client, # client for Chroma
-        collection_name=collection_name,
-        embedding_function=embedding_function,
-    )
-
-    retriever = vector_db.as_retriever()
-    docs = retriever.get_relevant_documents(query)
-
-    return docs.__str__()
+def wikipedia_search(query: str) -> str:
+    """Search Wikipedia for additional information to expand on research papers or when no papers can be found."""
+    global all_sources
+
+    api_wrapper = WikipediaAPIWrapper()
+    wikipedia_search = WikipediaQueryRun(api_wrapper=api_wrapper)
+    wikipedia_results = wikipedia_search.run(query)
+    all_sources += create_wikipedia_urls_from_text(wikipedia_results)
+    return wikipedia_results
+
+@tool
+def google_search(query: str) -> str:
+    """Search Google for additional results when you can't answer questions using arxiv search or wikipedia search."""
+    global all_sources
+
+    websearch = GoogleSearchAPIWrapper()
+    search_results:dict = websearch.results(query, 3)
+    cleaner_sources = format_search_results(search_results)
+    parsed_csources = parse_list_to_dicts(cleaner_sources)
+    add_many(parsed_csources)
+    all_sources += cleaner_sources
+
+    return cleaner_sources.__str__()
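Because the new memory_search and knowledgeBase_search functions are decorated with @tool, they are LangChain tool objects and can be exercised directly, outside the agent, when debugging the vector stores. A hedged sketch (the collection name value is hypothetical; memory_search reads it from the CONVERSATION_COLLECTION_NAME environment variable shown in the diff):

# Hypothetical debugging snippet, not part of the commit.
import os
from innovation_pathfinder_ai.structured_tools.structured_tools import memory_search

os.environ.setdefault("CONVERSATION_COLLECTION_NAME", "conversation-collection")  # hypothetical name

# Tool objects accept their single argument directly; this returns the stringified documents.
print(memory_search.invoke("previous research on mixture-of-experts agents"))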
requirements.txt CHANGED
@@ -10,4 +10,5 @@ chromadb
 google_api_python_client
 pypdf2
 sqlmodel
-rich
+rich
+fastapi