# DocGPT / app.py
import os
import random
import itertools

import streamlit as st
import validators
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from langchain_community.document_loaders import (
    TextLoader,
    Docx2txtLoader,
    WebBaseLoader,
    PyMuPDFLoader,
)
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import (
    HuggingFaceEmbeddings,
    HuggingFaceBgeEmbeddings,
    HuggingFaceInstructEmbeddings,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import QAGenerationChain
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain_openai import ChatOpenAI
from langchain.schema import SystemMessage, AIMessage, HumanMessage
from langchain.prompts import MessagesPlaceholder
from langsmith import Client
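# NOTE: this app mixes legacy `langchain` agent APIs (OpenAIFunctionsAgent,
# AgentTokenBufferMemory) with `langchain_community`/`langchain_openai` imports,
# so it assumes a pre-0.2.x LangChain pin where those agent classes still exist.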
client = Client()
st.set_page_config(page_title="DOC QA", page_icon=':book:')
starter_message = "Ask me anything about the Doc/Website Input!"
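# Display name -> Hugging Face model id for the selectable bi-encoder embedding models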
bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
'instructor-large': 'hkunlp/instructor-large',
'FlagEmbedding': 'BAAI/bge-base-en-v1.5'}
@st.cache_resource
def create_prompt():
    '''Create the agent prompt and the chat LLM used to answer questions'''
    llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4o")

    message = SystemMessage(
        content=(
            "You are a helpful chatbot tasked with answering questions about the context "
            "given through uploaded documents. "
            "Do not answer any question that is not related or relevant to the given context or uploaded documents. "
            "If there is any ambiguity, politely decline to answer any question not in the context provided."
        )
    )

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )

    return prompt, llm
def send_feedback(run_id, score):
    '''Log user feedback for a run to LangSmith (not cached, so every click is recorded)'''
    client.create_feedback(run_id, "user_score", score=score)
@st.cache_data
def save_file_locally(file):
    '''Save an uploaded file to a local temp directory and return its path'''
    # Ensure the temp directory exists before writing
    os.makedirs('tempdir', exist_ok=True)
    doc_path = os.path.join('tempdir', file.name)
    with open(doc_path, 'wb') as f:
        f.write(file.getbuffer())
    return doc_path
@st.cache_data
def load_docs(files, url=False):
    '''Load text from uploaded files or a web link into a single corpus string'''
    documents = []
    if not url:
        st.info("`Reading doc ...`")
        for file in files:
            file_extension = os.path.splitext(file.name)[1]
            doc_path = save_file_locally(file)
            if file_extension == ".pdf":
                pages = PyMuPDFLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".txt":
                pages = TextLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".docx":
                pages = Docx2txtLoader(doc_path)
                documents.extend(pages.load())
            else:
                st.warning('Please provide a .txt, .pdf, or .docx file.', icon="⚠️")
    else:
        st.info("`Reading web link ...`")
        loader = WebBaseLoader(files)
        documents = loader.load()
    # Concatenate the page contents into one corpus string
    return '\n'.join([doc.page_content for doc in documents])
# cache_resource (not cache_data): embedding models are unserializable resources
@st.cache_resource
def gen_embeddings(model_name):
    '''Instantiate the embedding model selected by the user'''
    if model_name == 'mpnet-base-v2':
        embeddings = HuggingFaceEmbeddings(model_name=bi_enc_dict[model_name])
    elif model_name == 'instructor-large':
        embeddings = HuggingFaceInstructEmbeddings(
            model_name=bi_enc_dict[model_name],
            query_instruction='Represent the question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the paragraph for retrieval: ',
        )
    elif model_name == 'FlagEmbedding':
        # BGE models work best with normalized embeddings (cosine similarity)
        encode_kwargs = {'normalize_embeddings': True}
        embeddings = HuggingFaceBgeEmbeddings(
            model_name=bi_enc_dict[model_name],
            encode_kwargs=encode_kwargs,
        )
    return embeddings
@st.cache_resource
def process_corpus(corpus, model_name, chunk_size=1000, overlap=50):
    '''Split the corpus, embed the chunks, and wrap the index as a retriever tool'''
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    # Display the number of text chunks
    num_chunks = len(texts)
    st.write(f"Number of text chunks: {num_chunks}")

    # Select embedding model
    embeddings = gen_embeddings(model_name)

    # Build a FAISS index over the chunks and expose it as a top-4 retriever
    retriever = FAISS.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 4})

    # Wrap the retriever as a tool the agent can call
    tool = create_retriever_tool(
        retriever,
        "search_docs",
        "Searches and returns documents using the context provided as a source, relevant to the user input question.",
    )
    return [tool]
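# Build the shared prompt and chat LLM once at startup (cached by st.cache_resource)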
prompt, llm = create_prompt()
@st.cache_resource
def generate_memory(text, model_name):
    '''Build the conversational agent executor and its token-buffer memory'''
    tools = process_corpus(text, model_name)
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )
    memory = AgentTokenBufferMemory(llm=llm)
    return memory, agent_executor
@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate N sample question-answer pairs, each drawn from a random
    chunk-sized span of the text.
    IN: text, N questions, chunk size to draw each question from
    OUT: eval set as a flat list of {question, answer} dicts'''
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")

    n = len(raw_text)
    # Guard against texts shorter than the requested chunk size
    starting_indices = [random.randint(0, max(0, n - chunk)) for _ in range(N)]
    sub_sequences = [raw_text[i:i + chunk] for i in starting_indices]

    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i+1}")
        except Exception:
            st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
            continue

    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    update.empty()
    ques_update.empty()
    return eval_set_full
def gen_side_bar_qa(text, model_name):
    '''Generate sample question-answer pairs and display them in the sidebar'''
    if text:
        # Generate question-answer pairs once and keep them in session state
        if 'eval_set' not in st.session_state:
            num_eval_questions = 10  # Number of question-answer pairs to generate
            st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

        # Display the question-answer pairs in the sidebar with smaller text
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
                <div class="css-card">
                <span class="card-tag">Question {i + 1}</span>
                <p style="font-size: 12px;">{qa_pair['question']}</p>
                <p style="font-size: 12px;">{qa_pair['answer']}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )
        st.write("Ready to answer your questions.")
# Add custom CSS
st.markdown(
"""
<style>
#MainMenu {visibility: hidden;
}
footer {visibility: hidden;
}
.css-card {
border-radius: 0px;
padding: 30px 10px 10px 10px;
background-color: black;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 10px;
font-family: "IBM Plex Sans", sans-serif;
}
.card-tag {
border-radius: 0px;
padding: 1px 5px 1px 5px;
margin-bottom: 10px;
position: absolute;
left: 0px;
top: 0px;
font-size: 0.6rem;
font-family: "IBM Plex Sans", sans-serif;
color: white;
background-color: green;
}
.css-zt5igj {left:0;
}
span.css-10trblm {margin-left:0;
}
div.css-1kyxreq {margin-top: -40px;
}
</style>
""",
unsafe_allow_html=True,
)
st.sidebar.image("img/logo.jpg")
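# Render the page header with a "beta" badge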
st.write(
    """
<div style="display: flex; align-items: center; margin-left: 0;">
<h1 style="display: inline-block;">DOC GPT</h1>
<sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
</div>
""",
unsafe_allow_html=True,
)
st.sidebar.title("Menu")
# RecursiveCharacterTextSplitter is the only text splitter used (see process_corpus)
uploaded_files = st.file_uploader("Upload PDF, TXT, or DOCX documents", type=[
    "pdf", "txt", "docx"], accept_multiple_files=True)
st.markdown(
"<h3 style='text-align: center; color: red;'>OR</h3>",
unsafe_allow_html=True,
)
url_text = st.text_input("Enter the URL of a web page you would like to load...")
model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
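# Main flow: build a corpus from uploaded files or a URL, then hand it to the chat UI below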
def render_chat(text, model_name):
    '''Render the sidebar sample questions and the main chat interface for the given corpus'''
    gen_side_bar_qa(text, model_name)
    memory, agent_executor = generate_memory(text, model_name)

    if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
        st.session_state["messages"] = [AIMessage(content=starter_message)]

    # Replay the conversation so far and keep the agent memory in sync
    for msg in st.session_state.messages:
        if isinstance(msg, AIMessage):
            st.chat_message("assistant").write(msg.content)
        elif isinstance(msg, HumanMessage):
            st.chat_message("user").write(msg.content)
        memory.chat_memory.add_message(msg)

    if user_question := st.chat_input(placeholder=starter_message):
        st.chat_message("user").write(user_question)
        with st.chat_message("assistant"):
            st_callback = StreamlitCallbackHandler(st.container())
            response = agent_executor(
                {"input": user_question, "history": st.session_state.messages},
                callbacks=[st_callback],
                include_run_info=True,
            )
            st.session_state.messages.append(AIMessage(content=response["output"]))
            st.write(response["output"])
            memory.save_context({"input": user_question}, response)
            st.session_state["messages"] = memory.buffer
            run_id = response["__run"].run_id

            # Thumbs up/down feedback, logged to LangSmith via send_feedback
            col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
            with col_text:
                st.text("Feedback:")
            with col1:
                st.button("👍", on_click=send_feedback, args=(run_id, 1))
            with col2:
                st.button("👎", on_click=send_feedback, args=(run_id, 0))


if uploaded_files:
    # Reprocess only when the set of uploaded files changes; invalidate cached sample questions
    if 'last_uploaded_files' not in st.session_state or st.session_state.last_uploaded_files != uploaded_files:
        st.session_state.last_uploaded_files = uploaded_files
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    # Load and process the uploaded PDF, TXT, or DOCX files
    raw_text = load_docs(uploaded_files)
    st.success("Documents uploaded and processed.")

    render_chat(raw_text, model_name)
elif url_text and validators.url(url_text):
    # Reprocess only when the URL changes; invalidate cached sample questions
    if 'url_files' not in st.session_state or st.session_state.url_files != url_text:
        st.session_state.url_files = url_text
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    # Load and process the web page
    loaded_docs = load_docs(url_text, url=True)
    st.success("Web document loaded and processed.")

    render_chat(loaded_docs, model_name)
st.markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-doc-gpt)")