Aabbhishekk committed on
Commit
b4484e7
β€’
1 Parent(s): 7eedacc

Delete app .py

Browse files
Files changed (1) hide show
  1. app .py +0 -177
app .py DELETED
@@ -1,177 +0,0 @@
1
- from langchain.agents import AgentType, Tool, initialize_agent
2
- from langchain.callbacks import StreamlitCallbackHandler
3
- from langchain.chains import RetrievalQA
4
- from langchain.chains.conversation.memory import ConversationBufferMemory
5
- from utils.ask_human import CustomAskHumanTool
6
- from utils.model_params import get_model_params
7
- from utils.prompts import create_agent_prompt, create_qa_prompt
8
- from PyPDF2 import PdfReader
9
- from langchain.vectorstores import FAISS
10
- from langchain.embeddings import HuggingFaceEmbeddings
11
- from langchain.embeddings import HuggingFaceHubEmbeddings
12
- from langchain import HuggingFaceHub
13
- import torch
14
- import streamlit as st
15
- from langchain.utilities import SerpAPIWrapper
16
- import os
17
# Credentials are read at import time; a missing variable fails fast with KeyError.
hf_token = os.environ["HF_TOKEN"]
serp_token = os.environ["SERP_TOKEN"]
repo_id = "sentence-transformers/all-mpnet-base-v2"

HUGGINGFACEHUB_API_TOKEN = hf_token
# Remote embedding endpoint used to vectorize the PDF chunks in main().
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

EMB_SBERT_MPNET_BASE = "sentence-transformers/all-mpnet-base-v2"
# App-level settings; persist_directory=None keeps the vector store in memory.
config = {
    "persist_directory": None,
    "load_in_8bit": False,
    "embedding": EMB_SBERT_MPNET_BASE,
}
33
-
34
-
35
def create_sbert_mpnet():
    """Build a local SBERT (all-mpnet-base-v2) embedding model, on GPU when available."""
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return HuggingFaceEmbeddings(
        model_name=EMB_SBERT_MPNET_BASE,
        model_kwargs={"device": target_device},
    )
38
-
39
# Hosted instruction-tuned LLM that drives both the QA chain and the agent.
llm = HuggingFaceHub(
    repo_id='mistralai/Mistral-7B-Instruct-v0.2',
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# Instantiate the local embedding model only for the configured backend.
if config["embedding"] == EMB_SBERT_MPNET_BASE:
    embedding = create_sbert_mpnet()
48
-
49
- from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
50
- from langchain.vectorstores import Chroma
51
- from langchain.chains import RetrievalQA
52
- from langchain import PromptTemplate
53
-
54
- ### PAGE ELEMENTS
55
-
56
- # st.set_page_config(
57
- # page_title="RAG Agent Demo",
58
- # page_icon="🦜",
59
- # layout="centered",
60
- # initial_sidebar_state="collapsed",
61
- # )
62
- # st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
63
-
64
-
65
def main():
    """Streamlit entry point: upload a PDF, index it, and chat with a RAG agent.

    Side effects: renders the Streamlit page, calls remote embedding/LLM/search
    services while answering.
    """
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF with RAG Agent 💬")

    # upload file
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")

    # Guard clause: nothing to do until a file is uploaded.
    if pdf is None:
        return

    # Extract the text. extract_text() can return None for image-only pages,
    # which would make `text += ...` raise TypeError — coerce to "".
    pdf_reader = PdfReader(pdf)
    text = ""
    for page in pdf_reader.pages:
        text += page.extract_text() or ""

    # FAISS.from_texts errors on an empty chunk list; bail out with a message.
    if not text.strip():
        st.warning("No extractable text found in the uploaded PDF.")
        return

    retriever = _build_retriever(text)
    agent = _build_agent(retriever)

    # question form
    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    # output container
    output_container = st.empty()
    if submit_clicked:
        output_container = output_container.container()
        output_container.chat_message("user").write(user_input)

        # Stream intermediate agent steps into the assistant message.
        answer_container = output_container.chat_message("assistant", avatar="🦜")
        st_callback = StreamlitCallbackHandler(answer_container)

        answer = agent.run(user_input, callbacks=[st_callback])

        answer_container = output_container.container()
        answer_container.chat_message("assistant").write(answer)


def _build_retriever(text):
    """Split raw PDF text into chunks, embed them remotely, and return a FAISS retriever (top-5)."""
    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    knowledge_base = FAISS.from_texts(texts, hf)
    return knowledge_base.as_retriever(search_kwargs={"k": 5})


def _build_agent(retriever):
    """Wire the retrieval QA chain and a web-search tool into a zero-shot ReAct agent.

    Returns the initialized agent executor; `llm`, `serp_token`, and the prompt
    helpers come from module scope.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": create_qa_prompt(),
        },
    )

    # NOTE(review): the original passed k=3 here, but ConversationBufferMemory
    # has no window-size parameter (that is ConversationBufferWindowMemory);
    # pydantic silently ignored it. Dropped so the code no longer implies a
    # 3-turn window that never existed.
    conversational_memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )

    # tool for db search
    db_search_tool = Tool(
        name="dbRetrievalTool",
        func=qa_chain,
        description="""Use this tool first to answer human questions. The input to this tool should be the question.""",
    )

    search = SerpAPIWrapper(serpapi_api_key=serp_token)

    google_searchtool = Tool(
        name="Current Search",
        func=search.run,
        description="use this tool to answer questions if the answer from other tools are not sufficient.",
    )

    # NOTE(review): the original constructed CustomAskHumanTool() but never
    # added it to the tools list; the unused instance has been removed.

    # agent prompt
    prefix, format_instructions, suffix = create_agent_prompt()

    # initialize agent
    return initialize_agent(
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        tools=[db_search_tool, google_searchtool],
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=conversational_memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )


if __name__ == '__main__':
    main()