ffeew committed on
Commit
275da20
β€’
1 Parent(s): f68c440

added duckduckgo search

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app.py +43 -20
  3. requirements.txt +5 -0
  4. utils.py +1 -1
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv
2
+ .env
3
+ __pycache__
app.py CHANGED
@@ -1,6 +1,15 @@
1
  import streamlit as st
2
  from utils import st_load_retriever, st_load_llm, StreamHandler
3
  from langchain.chains import RetrievalQAWithSourcesChain
 
 
 
 
 
 
 
 
 
4
 
5
  st.title("AIxplorer - A Smarter Google Scholar πŸŒπŸ“š")
6
  st.write(
@@ -13,28 +22,42 @@ st.subheader("Settings")
13
  col1, col2, col3 = st.columns(3)
14
 
15
  with col1:
16
- use_google = st.checkbox(
17
- "Use Google Search",
18
- value=True,
19
- help="Use Google Search to retrieve papers. If unchecked, will use the vector database.",
 
20
  )
 
21
  st.divider()
22
 
23
  llm = st_load_llm()
24
- retriever = st_load_retriever(llm, "vectordb" if not use_google else "google search")
25
 
26
- qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)
27
-
28
- user_input = st.text_area(
29
- "Enter your query here",
30
- help="Query should be on computer science as the RAG system is tuned to that domain.",
31
- )
32
-
33
-
34
- if st.button("Generate"):
35
- st.divider()
36
- st.subheader("Answer:")
37
- with st.spinner("Generating..."):
38
- container = st.empty()
39
- stream_handler = StreamHandler(container)
40
- response = qa_chain({"question": user_input}, callbacks=[stream_handler])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from utils import st_load_retriever, st_load_llm, StreamHandler
3
  from langchain.chains import RetrievalQAWithSourcesChain
4
+ from langchain.callbacks import StreamlitCallbackHandler
5
+ from langchain.agents import AgentType, initialize_agent, load_tools
6
+
7
+
8
+ RETRIEVAL_METHOD_MAP = {
9
+ "Vector Database": "vectordb",
10
+ "Google Search": "google search",
11
+ "DuckDuckGo Search": "duckduckgo search",
12
+ }
13
 
14
  st.title("AIxplorer - A Smarter Google Scholar πŸŒπŸ“š")
15
  st.write(
 
22
  col1, col2, col3 = st.columns(3)
23
 
24
  with col1:
25
+ retrieval_method = st.selectbox(
26
+ "Retrieval Mode",
27
+ RETRIEVAL_METHOD_MAP.keys(),
28
+ index=0,
29
+ help="The retrieval method used to retrieve supporting documents.",
30
  )
31
+
32
  st.divider()
33
 
34
  llm = st_load_llm()
 
35
 
36
+ # first path
37
+ if retrieval_method in ("Vector Database", "Google Search"):
38
+ retriever = st_load_retriever(llm, RETRIEVAL_METHOD_MAP[retrieval_method])
39
+ qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)
40
+ user_input = st.text_area(
41
+ "Enter your query here",
42
+ help="Query should be on computer science as the RAG system is tuned to that domain.",
43
+ )
44
+ if st.button("Generate"):
45
+ st.divider()
46
+ st.subheader("Answer:")
47
+ with st.spinner("Generating..."):
48
+ container = st.empty()
49
+ stream_handler = StreamHandler(container)
50
+ response = qa_chain({"question": user_input}, callbacks=[stream_handler])
51
+
52
+ # second path
53
+ else:
54
+ tools = load_tools(["ddg-search"])
55
+ agent = initialize_agent(
56
+ tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
57
+ )
58
+ if prompt := st.chat_input():
59
+ st.chat_message("user").write(prompt)
60
+ with st.chat_message("assistant"):
61
+ st_callback = StreamlitCallbackHandler(st.container())
62
+ response = agent.run(prompt, callbacks=[st_callback])
63
+ st.write(response)
requirements.txt CHANGED
@@ -11,14 +11,17 @@ beautifulsoup4==4.12.2
11
  blinker==1.7.0
12
  cachetools==5.3.2
13
  certifi==2023.11.17
 
14
  charset-normalizer==3.3.2
15
  chroma-hnswlib==0.7.3
16
  chromadb==0.4.20
17
  click==8.1.7
18
  coloredlogs==15.0.1
 
19
  dataclasses-json==0.6.3
20
  Deprecated==1.2.14
21
  diskcache==5.6.3
 
22
  fastapi==0.105.0
23
  filelock==3.13.1
24
  flatbuffers==23.5.26
@@ -54,6 +57,7 @@ langchain-community==0.0.5
54
  langchain-core==0.1.2
55
  langsmith==0.0.72
56
  llama_cpp_python==0.2.24
 
57
  markdown-it-py==3.0.0
58
  MarkupSafe==2.1.3
59
  marshmallow==3.20.1
@@ -100,6 +104,7 @@ pulsar-client==3.3.0
100
  pyarrow==14.0.2
101
  pyasn1==0.5.1
102
  pyasn1-modules==0.3.0
 
103
  pydantic==2.5.2
104
  pydantic_core==2.14.5
105
  pydeck==0.8.1b0
 
11
  blinker==1.7.0
12
  cachetools==5.3.2
13
  certifi==2023.11.17
14
+ cffi==1.16.0
15
  charset-normalizer==3.3.2
16
  chroma-hnswlib==0.7.3
17
  chromadb==0.4.20
18
  click==8.1.7
19
  coloredlogs==15.0.1
20
+ curl-cffi==0.5.10
21
  dataclasses-json==0.6.3
22
  Deprecated==1.2.14
23
  diskcache==5.6.3
24
+ duckduckgo-search==4.1.0
25
  fastapi==0.105.0
26
  filelock==3.13.1
27
  flatbuffers==23.5.26
 
57
  langchain-core==0.1.2
58
  langsmith==0.0.72
59
  llama_cpp_python==0.2.24
60
+ lxml==4.9.4
61
  markdown-it-py==3.0.0
62
  MarkupSafe==2.1.3
63
  marshmallow==3.20.1
 
104
  pyarrow==14.0.2
105
  pyasn1==0.5.1
106
  pyasn1-modules==0.3.0
107
+ pycparser==2.21
108
  pydantic==2.5.2
109
  pydantic_core==2.14.5
110
  pydeck==0.8.1b0
utils.py CHANGED
@@ -21,7 +21,7 @@ class StreamHandler(BaseCallbackHandler):
21
 
22
  @st.cache_resource
23
  def st_load_retriever(_llm, mode):
24
- model_kwargs = {"device": "cuda"}
25
  embeddings_model = HuggingFaceEmbeddings(
26
  model_name=config.embeddings_model,
27
  model_kwargs=model_kwargs,
 
21
 
22
  @st.cache_resource
23
  def st_load_retriever(_llm, mode):
24
+ model_kwargs = {"device": config.device}
25
  embeddings_model = HuggingFaceEmbeddings(
26
  model_name=config.embeddings_model,
27
  model_kwargs=model_kwargs,