|
from langchain.agents import AgentType, Tool, initialize_agent |
|
from langchain.callbacks import StreamlitCallbackHandler |
|
from langchain.chains import RetrievalQA |
|
from langchain.chains.conversation.memory import ConversationBufferMemory |
|
from utils.ask_human import CustomAskHumanTool |
|
from utils.model_params import get_model_params |
|
from utils.prompts import create_agent_prompt, create_qa_prompt |
|
from PyPDF2 import PdfReader |
|
from langchain.vectorstores import FAISS |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.embeddings import HuggingFaceHubEmbeddings |
|
from langchain import HuggingFaceHub |
|
import torch |
|
import streamlit as st |
|
from langchain.utilities import SerpAPIWrapper |
|
from langchain.tools import DuckDuckGoSearchRun |
|
import os |
|
# --- Credentials -------------------------------------------------------------
# HF_TOKEN is mandatory: both the embedding model and the LLM below are created
# with it at import time. Fail with an actionable message instead of a bare
# KeyError('HF_TOKEN'); the exception type stays KeyError for compatibility.
try:
    hf_token = os.environ["HF_TOKEN"]
except KeyError:
    raise KeyError(
        "HF_TOKEN environment variable is not set; it must contain a "
        "Hugging Face Hub API token."
    ) from None

# Optional: nothing in this file reads serp_token (SerpAPIWrapper is imported
# but the agent's web search uses DuckDuckGo), so a missing SERP_TOKEN must not
# prevent the app from starting. None when unset.
serp_token = os.environ.get("SERP_TOKEN")

# --- Models ------------------------------------------------------------------
# Sentence-transformers model used to embed PDF chunks.
repo_id = "sentence-transformers/all-mpnet-base-v2"

HUGGINGFACEHUB_API_TOKEN = hf_token

# Embedding client for the model above, accessed through the Hugging Face Hub
# with the "feature-extraction" task.
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# Chat/instruction LLM shared by the retrieval QA chain and the agent.
llm = HuggingFaceHub(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)
|
|
|
|
|
|
|
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.chains import RetrievalQA |
|
from langchain import PromptTemplate |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_pdf_text(pdf):
    """Return the text of every page of ``pdf`` concatenated into one string.

    ``pdf`` is anything ``PdfReader`` accepts; here it is the file object
    returned by ``st.file_uploader``.
    """
    reader = PdfReader(pdf)
    # "".join avoids the quadratic += loop; `or ""` defensively treats a page
    # whose extraction yields nothing as empty instead of breaking the concat.
    return "".join(page.extract_text() or "" for page in reader.pages)


def _build_retriever(text):
    """Chunk ``text``, embed the chunks into a FAISS index, and return a retriever.

    Embeddings come from the module-level ``hf`` client. The retriever returns
    the 3 most similar chunks for each query.
    """
    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    chunks = text_splitter.split_text(text)
    knowledge_base = FAISS.from_texts(chunks, hf)
    return knowledge_base.as_retriever(search_kwargs={"k": 3})


def _build_agent(retriever, memory):
    """Assemble the zero-shot ReAct agent with a PDF-QA tool and a web-search tool.

    Args:
        retriever: Retriever over the uploaded PDF's chunks.
        memory: LangChain memory object stored under the ``chat_history`` key.

    Returns:
        The agent executor produced by ``initialize_agent``.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={
            "prompt": create_qa_prompt(),
        },
    )

    db_search_tool = Tool(
        name="dbRetrievalTool",
        # qa_chain.run returns only the answer string; passing the chain object
        # itself would feed the agent the whole {"query": ..., "result": ...} dict.
        func=qa_chain.run,
        description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
    )

    search = DuckDuckGoSearchRun()
    search_tool = Tool(
        name="search",
        func=search.run,
        description="use this tool to answer real time or current search related questions.",
    )

    prefix, format_instructions, suffix = create_agent_prompt()

    return initialize_agent(
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        tools=[db_search_tool, search_tool],
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )


def main():
    """Streamlit entry point: upload a PDF, then ask an agent questions about it.

    Streamlit re-executes this script on every interaction. The FAISS retriever
    and the conversation memory are therefore kept in ``st.session_state`` and
    rebuilt only when a PDF with different content is uploaded. Otherwise every
    question would re-embed the whole document, and the agent would never
    remember earlier turns.
    """
    import hashlib

    from langchain.chains.conversation.memory import ConversationBufferWindowMemory

    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF powered by Search Agents 💬")

    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
    if pdf is None:
        return

    # Identify the upload by content so re-uploading the same file reuses the index.
    pdf_digest = hashlib.sha256(pdf.getvalue()).hexdigest()
    if st.session_state.get("pdf_digest") != pdf_digest:
        text = _extract_pdf_text(pdf)
        if not text.strip():
            # A FAISS index cannot be built from zero chunks (e.g. image-only PDFs).
            st.warning("No extractable text was found in this PDF.")
            return
        st.session_state["retriever"] = _build_retriever(text)
        # ConversationBufferMemory has no `k` option (it was silently ignored);
        # the window variant actually keeps only the last k exchanges.
        st.session_state["conversational_memory"] = ConversationBufferWindowMemory(
            memory_key="chat_history", k=3, return_messages=True
        )
        # Record the digest last so a failed build is retried on the next rerun.
        st.session_state["pdf_digest"] = pdf_digest

    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    output_container = st.empty()
    if not submit_clicked:
        return
    if not user_input.strip():
        st.warning("Please enter a question.")
        return

    agent = _build_agent(
        st.session_state["retriever"], st.session_state["conversational_memory"]
    )

    output_container = output_container.container()
    output_container.chat_message("user").write(user_input)

    with st.chat_message("assistant"):
        # Streams the agent's intermediate steps into this chat bubble.
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent.run(user_input, callbacks=[st_callback])
        st.write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the app only when this file is executed as the main script.
if "__main__" == __name__:
    main()