Aabbhishekk committed on
Commit
b7e13eb
•
1 Parent(s): b4484e7

Upload app.py

Files changed (1)
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ from langchain.agents import AgentType, Tool, initialize_agent
+ from langchain.callbacks import StreamlitCallbackHandler
+ from langchain.chains import RetrievalQA
+ from langchain.chains.conversation.memory import ConversationBufferMemory
+ from utils.ask_human import CustomAskHumanTool
+ from utils.model_params import get_model_params
+ from utils.prompts import create_agent_prompt, create_qa_prompt
+ from PyPDF2 import PdfReader
+ from langchain.vectorstores import FAISS
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.embeddings import HuggingFaceHubEmbeddings
+ from langchain import HuggingFaceHub
+ import torch
+ import streamlit as st
+ from langchain.utilities import SerpAPIWrapper
+ import os
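+
+ # API tokens for the Hugging Face Hub and SerpAPI, read from environment variables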
+ hf_token = os.environ['HF_TOKEN']
+ serp_token = os.environ['SERP_TOKEN']
+ repo_id = "sentence-transformers/all-mpnet-base-v2"
+
+ HUGGINGFACEHUB_API_TOKEN = hf_token
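+ # Sentence-transformer embeddings served remotely through the Hugging Face Inference API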
+ hf = HuggingFaceHubEmbeddings(
+     repo_id=repo_id,
+     task="feature-extraction",
+     huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
+ )
+
+ EMB_SBERT_MPNET_BASE = "sentence-transformers/all-mpnet-base-v2"
+ config = {
+     "persist_directory": None,
+     "load_in_8bit": False,
+     "embedding": EMB_SBERT_MPNET_BASE,
+ }
+
+
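+ # Local SBERT MPNet embeddings, using the GPU when available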
+ def create_sbert_mpnet():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     return HuggingFaceEmbeddings(model_name=EMB_SBERT_MPNET_BASE, model_kwargs={"device": device})
+
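+ # Mistral-7B-Instruct served through the Hugging Face Hub, used as the agent's LLM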
+ llm = HuggingFaceHub(
+     repo_id='mistralai/Mistral-7B-Instruct-v0.2',
+     huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
+ )
+
+ if config["embedding"] == EMB_SBERT_MPNET_BASE:
+     embedding = create_sbert_mpnet()
+
+ from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain import PromptTemplate
+
+ ### PAGE ELEMENTS
+
+ # st.set_page_config(
+ #     page_title="RAG Agent Demo",
+ #     page_icon="🦜",
+ #     layout="centered",
+ #     initial_sidebar_state="collapsed",
+ # )
+ # st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
+
+
+ def main():
+
+     st.set_page_config(page_title="Ask your PDF powered by Search Agents")
+     st.header("Ask your PDF with RAG Agent 💬")
+
+     # upload file
+     pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
+
+     # extract the text
+     if pdf is not None:
+         pdf_reader = PdfReader(pdf)
+         text = ""
+         for page in pdf_reader.pages:
+             text += page.extract_text()
+
+         # Split documents and create text snippets
+         text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
+         texts = text_splitter.split_text(text)
+
+         embeddings = hf
+         knowledge_base = FAISS.from_texts(texts, embeddings)
+
+         retriever = knowledge_base.as_retriever(search_kwargs={"k": 5})
+         # retriever = FAISS.as_retriever()
+         # persist_directory = config["persist_directory"]
+         # vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
+         # retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+
+         # mode = st.selectbox(
+         #     label="Select agent type",
+         #     options=("Agent with AskHuman tool", "Traditional RAG Agent", "Search Agent"),
+         # )
+
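+         # Retrieval QA chain that answers questions from the uploaded PDF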
+         qa_chain = RetrievalQA.from_chain_type(
+             llm=llm,
+             chain_type="stuff",
+             retriever=retriever,
+             return_source_documents=True,
+             chain_type_kwargs={
+                 "prompt": create_qa_prompt(),
+             },
+         )
+
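+         # conversation memory passed to the agent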
+         conversational_memory = ConversationBufferMemory(
+             memory_key="chat_history", k=3, return_messages=True
+         )
+
+         # tool for db search
+         db_search_tool = Tool(
+             name="dbRetrievalTool",
+             func=qa_chain,
+             description="""Use this tool first to answer human questions. The input to this tool should be the question.""",
+         )
+
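+         # tool for web search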
+         search = SerpAPIWrapper(serpapi_api_key=serp_token)
+
+         google_searchtool = Tool(
+             name="Current Search",
+             func=search.run,
+             description="Use this tool to answer questions if the answers from the other tools are not sufficient.",
+         )
+
+         # tool for asking human
+         human_ask_tool = CustomAskHumanTool()
+         # agent prompt
+         prefix, format_instructions, suffix = create_agent_prompt()
+
+         # initialize agent
+         agent = initialize_agent(
+             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+             tools=[db_search_tool, google_searchtool],
+             llm=llm,
+             verbose=True,
+             max_iterations=5,
+             early_stopping_method="generate",
+             memory=conversational_memory,
+             agent_kwargs={
+                 "prefix": prefix,
+                 "format_instructions": format_instructions,
+                 "suffix": suffix,
+             },
+             handle_parsing_errors=True,
+         )
+
+         # question form
+         with st.form(key="form"):
+             user_input = st.text_input("Ask your question")
+             submit_clicked = st.form_submit_button("Submit Question")
+
+         # output container
+         output_container = st.empty()
+         if submit_clicked:
+             output_container = output_container.container()
+             output_container.chat_message("user").write(user_input)
+
+             answer_container = output_container.chat_message("assistant", avatar="🦜")
+             st_callback = StreamlitCallbackHandler(answer_container)
+
+             answer = agent.run(user_input, callbacks=[st_callback])
+
+             answer_container = output_container.container()
+             answer_container.chat_message("assistant").write(answer)
+
+
+ if __name__ == '__main__':
+     main()