File size: 5,845 Bytes
b7e13eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import RetrievalQA
from langchain.chains.conversation.memory import ConversationBufferMemory
from utils.ask_human import CustomAskHumanTool
from utils.model_params import get_model_params
from utils.prompts import create_agent_prompt, create_qa_prompt
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain import HuggingFaceHub
import torch
import streamlit as st
from langchain.utilities import SerpAPIWrapper
import os
def _require_env(name):
    """Return the value of the environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set; the message names the
            variable so a missing deployment secret is easy to diagnose.
    """
    try:
        return os.environ[name]
    except KeyError:
        raise RuntimeError(
            f"Required environment variable {name!r} is not set."
        ) from None


hf_token = _require_env("HF_TOKEN")
serp_token = _require_env("SERP_TOKEN")

# Sentence-transformers model id, used both for the Hub embedder below and as
# the value selected by config["embedding"].
EMB_SBERT_MPNET_BASE = "sentence-transformers/all-mpnet-base-v2"
repo_id = EMB_SBERT_MPNET_BASE
HUGGINGFACEHUB_API_TOKEN = hf_token

# Embedder whose vectors are computed remotely through the Hugging Face Hub
# (feature-extraction task) rather than by a locally loaded model.
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

config = {
    "persist_directory": None,
    "load_in_8bit": False,
    "embedding": EMB_SBERT_MPNET_BASE,
}
def create_sbert_mpnet():
    """Build a locally loaded SBERT MPNet embedder on the best available device.

    Returns:
        A ``HuggingFaceEmbeddings`` for ``EMB_SBERT_MPNET_BASE``, placed on
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    model_kwargs = {"device": target_device}
    return HuggingFaceEmbeddings(
        model_name=EMB_SBERT_MPNET_BASE,
        model_kwargs=model_kwargs,
    )
# Hub-hosted LLM shared by the retrieval QA chain and the agent in main().
llm = HuggingFaceHub(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# Locally loaded embedder chosen by config["embedding"]. Fail fast on an
# unknown choice instead of leaving `embedding` silently undefined.
# NOTE(review): nothing in this file reads `embedding` (the FAISS index is
# built with `hf`), yet this loads a model at import time -- confirm it is
# used elsewhere before keeping it.
if config["embedding"] == EMB_SBERT_MPNET_BASE:
    embedding = create_sbert_mpnet()
else:
    raise ValueError(f"Unsupported embedding model: {config['embedding']!r}")
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
### PAGE ELEMENTS
# st.set_page_config(
# page_title="RAG Agent Demo",
# page_icon="🦜",
# layout="centered",
# initial_sidebar_state="collapsed",
# )
# st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
# Upper bound on how many distinct PDFs keep a cached FAISS index in memory.
_KB_CACHE_MAX_ENTRIES = 8


def _extract_pdf_text(pdf):
    """Return the text of every page of the uploaded PDF *pdf*, concatenated.

    A page whose ``extract_text()`` returns a falsy value contributes an
    empty string instead of breaking the concatenation.
    """
    reader = PdfReader(pdf)
    return "".join(page.extract_text() or "" for page in reader.pages)


@st.cache_resource(max_entries=_KB_CACHE_MAX_ENTRIES)
def _build_retriever(text):
    """Chunk *text*, embed the chunks with ``hf`` and return a top-5 FAISS retriever.

    Cached with ``st.cache_resource`` keyed on *text*. Streamlit re-executes
    the script on every widget interaction (including each question submit),
    and without the cache every chunk would be re-embedded each time.
    """
    splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    chunks = splitter.split_text(text)
    knowledge_base = FAISS.from_texts(chunks, hf)
    return knowledge_base.as_retriever(search_kwargs={"k": 5})


def _build_agent(retriever):
    """Assemble the ReAct agent: PDF retrieval tool first, SerpAPI search as fallback.

    Args:
        retriever: Retriever over the uploaded PDF's chunks.

    Returns:
        The agent executor produced by ``initialize_agent``.
    """
    # ``k`` (keep only the last k exchanges) is a field of
    # ConversationBufferWindowMemory; ConversationBufferMemory has no such field.
    from langchain.chains.conversation.memory import ConversationBufferWindowMemory

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": create_qa_prompt()},
    )

    def answer_from_pdf(question):
        # With return_source_documents=True the chain returns a dict
        # (query / result / source_documents). Hand the agent only the answer
        # text, not the stringified dict with full Document reprs.
        return qa_chain({"query": question})["result"]

    db_search_tool = Tool(
        name="dbRetrievalTool",
        func=answer_from_pdf,
        description="""Use this tool first to answer human questions. The input to this tool should be the question.""",
    )
    search = SerpAPIWrapper(serpapi_api_key=serp_token)
    google_searchtool = Tool(
        name="Current Search",
        func=search.run,
        description="use this tool to answer questions if the answer from other tools are not sufficient.",
    )
    # NOTE: CustomAskHumanTool is imported by this module but not offered to
    # the agent; add an instance to ``tools`` to let the agent ask the user.

    # NOTE(review): a fresh memory is created for every question because
    # Streamlit reruns main(), so history does not carry across questions.
    # Keep the agent in st.session_state if multi-turn memory is wanted.
    conversational_memory = ConversationBufferWindowMemory(
        memory_key="chat_history", k=3, return_messages=True
    )
    prefix, format_instructions, suffix = create_agent_prompt()
    return initialize_agent(
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        tools=[db_search_tool, google_searchtool],
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=conversational_memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )


def main():
    """Streamlit entry point: upload a PDF, then ask the RAG agent about it."""
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF with RAG Agent 💬")
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
    if pdf is None:
        return

    text = _extract_pdf_text(pdf)
    # PDFs without a text layer (e.g. scanned pages) yield no text, and
    # building a FAISS index from zero chunks fails.
    if not text.strip():
        st.warning("No extractable text was found in this PDF.")
        return
    retriever = _build_retriever(text)

    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    output_container = st.empty()
    if not submit_clicked:
        return
    if not user_input.strip():
        st.warning("Please enter a question before submitting.")
        return

    agent = _build_agent(retriever)
    output_container = output_container.container()
    output_container.chat_message("user").write(user_input)
    # Intermediate agent steps are streamed into this bubble by the callback.
    answer_container = output_container.chat_message("assistant", avatar="🦜")
    st_callback = StreamlitCallbackHandler(answer_container)
    answer = agent.run(user_input, callbacks=[st_callback])
    output_container.container().chat_message("assistant").write(answer)
# Run the Streamlit app only when executed as a script, not on import.
if __name__ == '__main__':
    main()