added duckduckgo search
Browse files- .gitignore +3 -0
- app.py +43 -20
- requirements.txt +5 -0
- utils.py +1 -1
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
venv
|
2 |
+
.env
|
3 |
+
__pycache__
|
app.py
CHANGED
@@ -1,6 +1,15 @@
|
|
1 |
import streamlit as st
|
2 |
from utils import st_load_retriever, st_load_llm, StreamHandler
|
3 |
from langchain.chains import RetrievalQAWithSourcesChain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
st.title("AIxplorer - A Smarter Google Scholar ππ")
|
6 |
st.write(
|
@@ -13,28 +22,42 @@ st.subheader("Settings")
|
|
13 |
col1, col2, col3 = st.columns(3)
|
14 |
|
15 |
with col1:
|
16 |
-
|
17 |
-
"
|
18 |
-
|
19 |
-
|
|
|
20 |
)
|
|
|
21 |
st.divider()
|
22 |
|
23 |
llm = st_load_llm()
|
24 |
-
retriever = st_load_retriever(llm, "vectordb" if not use_google else "google search")
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
if st.button("Generate"):
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from utils import st_load_retriever, st_load_llm, StreamHandler
|
3 |
from langchain.chains import RetrievalQAWithSourcesChain
|
4 |
+
from langchain.callbacks import StreamlitCallbackHandler
|
5 |
+
from langchain.agents import AgentType, initialize_agent, load_tools
|
6 |
+
|
7 |
+
|
8 |
+
RETRIEVAL_METHOD_MAP = {
|
9 |
+
"Vector Database": "vectordb",
|
10 |
+
"Google Search": "google search",
|
11 |
+
"DuckDuckGo Search": "duckduckgo search",
|
12 |
+
}
|
13 |
|
14 |
st.title("AIxplorer - A Smarter Google Scholar ππ")
|
15 |
st.write(
|
|
|
22 |
col1, col2, col3 = st.columns(3)
|
23 |
|
24 |
with col1:
|
25 |
+
retrieval_method = st.selectbox(
|
26 |
+
"Retrieval Mode",
|
27 |
+
RETRIEVAL_METHOD_MAP.keys(),
|
28 |
+
index=0,
|
29 |
+
help="The retrieval method used to retrieve supporting documents.",
|
30 |
)
|
31 |
+
|
32 |
st.divider()
|
33 |
|
34 |
llm = st_load_llm()
|
|
|
35 |
|
36 |
+
# first path
|
37 |
+
if retrieval_method in ("Vector Database", "Google Search"):
|
38 |
+
retriever = st_load_retriever(llm, RETRIEVAL_METHOD_MAP[retrieval_method])
|
39 |
+
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)
|
40 |
+
user_input = st.text_area(
|
41 |
+
"Enter your query here",
|
42 |
+
help="Query should be on computer science as the RAG system is tuned to that domain.",
|
43 |
+
)
|
44 |
+
if st.button("Generate"):
|
45 |
+
st.divider()
|
46 |
+
st.subheader("Answer:")
|
47 |
+
with st.spinner("Generating..."):
|
48 |
+
container = st.empty()
|
49 |
+
stream_handler = StreamHandler(container)
|
50 |
+
response = qa_chain({"question": user_input}, callbacks=[stream_handler])
|
51 |
+
|
52 |
+
# second path
|
53 |
+
else:
|
54 |
+
tools = load_tools(["ddg-search"])
|
55 |
+
agent = initialize_agent(
|
56 |
+
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
57 |
+
)
|
58 |
+
if prompt := st.chat_input():
|
59 |
+
st.chat_message("user").write(prompt)
|
60 |
+
with st.chat_message("assistant"):
|
61 |
+
st_callback = StreamlitCallbackHandler(st.container())
|
62 |
+
response = agent.run(prompt, callbacks=[st_callback])
|
63 |
+
st.write(response)
|
requirements.txt
CHANGED
@@ -11,14 +11,17 @@ beautifulsoup4==4.12.2
|
|
11 |
blinker==1.7.0
|
12 |
cachetools==5.3.2
|
13 |
certifi==2023.11.17
|
|
|
14 |
charset-normalizer==3.3.2
|
15 |
chroma-hnswlib==0.7.3
|
16 |
chromadb==0.4.20
|
17 |
click==8.1.7
|
18 |
coloredlogs==15.0.1
|
|
|
19 |
dataclasses-json==0.6.3
|
20 |
Deprecated==1.2.14
|
21 |
diskcache==5.6.3
|
|
|
22 |
fastapi==0.105.0
|
23 |
filelock==3.13.1
|
24 |
flatbuffers==23.5.26
|
@@ -54,6 +57,7 @@ langchain-community==0.0.5
|
|
54 |
langchain-core==0.1.2
|
55 |
langsmith==0.0.72
|
56 |
llama_cpp_python==0.2.24
|
|
|
57 |
markdown-it-py==3.0.0
|
58 |
MarkupSafe==2.1.3
|
59 |
marshmallow==3.20.1
|
@@ -100,6 +104,7 @@ pulsar-client==3.3.0
|
|
100 |
pyarrow==14.0.2
|
101 |
pyasn1==0.5.1
|
102 |
pyasn1-modules==0.3.0
|
|
|
103 |
pydantic==2.5.2
|
104 |
pydantic_core==2.14.5
|
105 |
pydeck==0.8.1b0
|
|
|
11 |
blinker==1.7.0
|
12 |
cachetools==5.3.2
|
13 |
certifi==2023.11.17
|
14 |
+
cffi==1.16.0
|
15 |
charset-normalizer==3.3.2
|
16 |
chroma-hnswlib==0.7.3
|
17 |
chromadb==0.4.20
|
18 |
click==8.1.7
|
19 |
coloredlogs==15.0.1
|
20 |
+
curl-cffi==0.5.10
|
21 |
dataclasses-json==0.6.3
|
22 |
Deprecated==1.2.14
|
23 |
diskcache==5.6.3
|
24 |
+
duckduckgo-search==4.1.0
|
25 |
fastapi==0.105.0
|
26 |
filelock==3.13.1
|
27 |
flatbuffers==23.5.26
|
|
|
57 |
langchain-core==0.1.2
|
58 |
langsmith==0.0.72
|
59 |
llama_cpp_python==0.2.24
|
60 |
+
lxml==4.9.4
|
61 |
markdown-it-py==3.0.0
|
62 |
MarkupSafe==2.1.3
|
63 |
marshmallow==3.20.1
|
|
|
104 |
pyarrow==14.0.2
|
105 |
pyasn1==0.5.1
|
106 |
pyasn1-modules==0.3.0
|
107 |
+
pycparser==2.21
|
108 |
pydantic==2.5.2
|
109 |
pydantic_core==2.14.5
|
110 |
pydeck==0.8.1b0
|
utils.py
CHANGED
@@ -21,7 +21,7 @@ class StreamHandler(BaseCallbackHandler):
|
|
21 |
|
22 |
@st.cache_resource
|
23 |
def st_load_retriever(_llm, mode):
|
24 |
-
model_kwargs = {"device":
|
25 |
embeddings_model = HuggingFaceEmbeddings(
|
26 |
model_name=config.embeddings_model,
|
27 |
model_kwargs=model_kwargs,
|
|
|
21 |
|
22 |
@st.cache_resource
|
23 |
def st_load_retriever(_llm, mode):
|
24 |
+
model_kwargs = {"device": config.device}
|
25 |
embeddings_model = HuggingFaceEmbeddings(
|
26 |
model_name=config.embeddings_model,
|
27 |
model_kwargs=model_kwargs,
|