import os
import random
import itertools

import streamlit as st
import validators
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from langchain_community.document_loaders import TextLoader, Docx2txtLoader, WebBaseLoader, PyMuPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings, HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import QAGenerationChain
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain_openai import ChatOpenAI
from langchain.schema import SystemMessage, AIMessage, HumanMessage
from langchain.prompts import MessagesPlaceholder
from langsmith import Client
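# Assumed environment: OPENAI_API_KEY must be set for the OpenAI chat model,
# and LANGCHAIN_API_KEY for the LangSmith client used to log user feedback.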
client = Client()

st.set_page_config(page_title="DOC QA", page_icon=':book:')

starter_message = "Ask me anything about the Doc/Website Input!"

bi_enc_dict = {'mpnet-base-v2': "all-mpnet-base-v2",
               'instructor-large': 'hkunlp/instructor-large',
               'FlagEmbedding': 'BAAI/bge-base-en-v1.5'}
def create_prompt():
    '''Create the agent prompt and the underlying chat model'''
    llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4o")
    message = SystemMessage(
        content=(
            "You are a helpful chatbot tasked with answering questions about the context given through uploaded documents. "
            "Do not answer any question that is not related to the given context or uploaded documents. "
            "If there is any ambiguity, politely decline to answer questions outside the provided context."
        )
    )
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    return prompt, llm
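# Note: OpenAIFunctionsAgent relies on the OpenAI function-calling API, so the
# model passed above must support function calling (gpt-4o does).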
def send_feedback(run_id, score):
    '''Log user feedback (1 = thumbs up, 0 = thumbs down) for a run to LangSmith'''
    client.create_feedback(run_id, "user_score", score=score)
def save_file_locally(file):
    '''Save an uploaded file locally and return its path'''
    os.makedirs('tempdir', exist_ok=True)
    doc_path = os.path.join('tempdir', file.name)
    with open(doc_path, 'wb') as f:
        f.write(file.getbuffer())
    return doc_path
def load_docs(files, url=False):
    '''Load uploaded files (or a web link) and return their combined text'''
    documents = []
    if not url:
        st.info("`Reading doc ...`")
        for file in files:
            file_extension = os.path.splitext(file.name)[1].lower()
            doc_path = save_file_locally(file)
            if file_extension == ".pdf":
                pages = PyMuPDFLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".txt":
                pages = TextLoader(doc_path)
                documents.extend(pages.load())
            elif file_extension == ".docx":
                pages = Docx2txtLoader(doc_path)
                documents.extend(pages.load())
            else:
                st.warning('Please provide a txt, pdf or docx file.', icon="⚠️")
    else:
        st.info("`Reading web link ...`")
        loader = WebBaseLoader(files)
        documents = loader.load()
    return ','.join([doc.page_content for doc in documents])
def gen_embeddings(model_name):
    '''Instantiate the embedding model for the given model name'''
    if model_name == 'mpnet-base-v2':
        embeddings = HuggingFaceEmbeddings(model_name=bi_enc_dict[model_name])
    elif model_name == 'instructor-large':
        embeddings = HuggingFaceInstructEmbeddings(model_name=bi_enc_dict[model_name],
                                                   query_instruction='Represent the question for retrieving supporting paragraphs: ',
                                                   embed_instruction='Represent the paragraph for retrieval: ')
    elif model_name == 'FlagEmbedding':
        encode_kwargs = {'normalize_embeddings': True}
        embeddings = HuggingFaceBgeEmbeddings(model_name=bi_enc_dict[model_name],
                                              encode_kwargs=encode_kwargs)
    else:
        raise ValueError(f"Unknown embedding model: {model_name}")
    return embeddings
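# Illustrative usage of gen_embeddings (commented out so the app does not
# download a model at import time); embed_query is part of the shared
# LangChain Embeddings interface implemented by all three classes:
#   embeddings = gen_embeddings('FlagEmbedding')
#   vector = embeddings.embed_query("What is this document about?")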
def process_corpus(corpus, model_name, chunk_size=1000, overlap=50):
    '''Process text for Semantic Search'''
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    # Display the number of text chunks
    num_chunks = len(texts)
    st.write(f"Number of text chunks: {num_chunks}")

    # Select embedding model
    embeddings = gen_embeddings(model_name)

    # Create vectorstore and expose it as a retriever
    retriever = FAISS.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 4})

    # Create retriever tool
    tool = create_retriever_tool(
        retriever,
        "search_docs",
        "Searches and returns documents using the context provided as a source, relevant to the user input question.",
    )
    tools = [tool]
    return tools
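# Note: the FAISS index is rebuilt from scratch on every Streamlit rerun;
# for large corpora, caching the tools (e.g. with st.cache_resource) would
# avoid re-embedding the same text on each interaction.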
prompt, llm = create_prompt()

def generate_memory(text, model_name):
    '''Build the retrieval tools, the agent executor and its conversation memory'''
    tools = process_corpus(text, model_name)
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )
    memory = AgentTokenBufferMemory(llm=llm)
    return memory, agent_executor
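# AgentTokenBufferMemory keeps a token-bounded buffer of the conversation,
# dropping the oldest messages once its token limit is reached, so long chats
# stay within the model's context window.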
def generate_eval(raw_text, N, chunk):
    '''Generate N questions from random contexts of chunk characters.

    IN: text, N questions, chunk size to draw each question from in the doc
    OUT: eval set as a JSON list
    '''
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")
    n = len(raw_text)
    chunk = min(chunk, n)  # guard against texts shorter than the chunk size
    starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
    sub_sequences = [raw_text[i:i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i+1}")
        except Exception:
            st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
            continue
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    update.empty()
    ques_update.empty()
    return eval_set_full
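# Each item returned by generate_eval is a dict with 'question' and 'answer'
# keys (the shape QAGenerationChain emits and gen_side_bar_qa renders below).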
def gen_side_bar_qa(text, model_name):
    '''Render generated question-answer pairs in the sidebar'''
    if text:
        # Check if there are no generated question-answer pairs in the session state
        if 'eval_set' not in st.session_state:
            # Use the generate_eval function to generate question-answer pairs
            num_eval_questions = 10  # Number of question-answer pairs to generate
            st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

        # Display the question-answer pairs in the sidebar with smaller text
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
                <div class="css-card">
                <span class="card-tag">Question {i + 1}</span>
                <p style="font-size: 12px;">{qa_pair['question']}</p>
                <p style="font-size: 12px;">{qa_pair['answer']}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )

        st.write("Ready to answer your questions.")
# Add custom CSS
st.markdown(
    """
    <style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    .css-card {
        border-radius: 0px;
        padding: 30px 10px 10px 10px;
        background-color: black;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 10px;
        font-family: "IBM Plex Sans", sans-serif;
    }
    .card-tag {
        border-radius: 0px;
        padding: 1px 5px 1px 5px;
        margin-bottom: 10px;
        position: absolute;
        left: 0px;
        top: 0px;
        font-size: 0.6rem;
        font-family: "IBM Plex Sans", sans-serif;
        color: white;
        background-color: green;
    }
    .css-zt5igj {left: 0;}
    span.css-10trblm {margin-left: 0;}
    div.css-1kyxreq {margin-top: -40px;}
    </style>
    """,
    unsafe_allow_html=True,
)
st.sidebar.image("img/logo.jpg")

st.write(
    """
    <div style="display: flex; align-items: center; margin-left: 0;">
    <h1 style="display: inline-block;">DOC GPT</h1>
    <sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
    </div>
    """,
    unsafe_allow_html=True,
)

st.sidebar.title("Menu")

# Use RecursiveCharacterTextSplitter as the default and only text splitter
splitter_type = "RecursiveCharacterTextSplitter"

uploaded_files = st.file_uploader("Upload a PDF, TXT or DOCX Document", type=[
                                  "pdf", "txt", "docx"], accept_multiple_files=True)

st.markdown(
    "<h3 style='text-align: center; color: red;'>OR</h3>",
    unsafe_allow_html=True,
)

url_text = st.text_input("Please enter a URL for an HTML page you would like to load...")

model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
def render_chat(memory, agent_executor):
    '''Render the chat history, run the agent on new input and collect feedback'''
    if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
        st.session_state["messages"] = [AIMessage(content=starter_message)]

    for msg in st.session_state.messages:
        if isinstance(msg, AIMessage):
            st.chat_message("assistant").write(msg.content)
        elif isinstance(msg, HumanMessage):
            st.chat_message("user").write(msg.content)
        memory.chat_memory.add_message(msg)

    if user_question := st.chat_input(placeholder=starter_message):
        st.chat_message("user").write(user_question)
        with st.chat_message("assistant"):
            st_callback = StreamlitCallbackHandler(st.container())
            response = agent_executor(
                {"input": user_question, "history": st.session_state.messages},
                callbacks=[st_callback],
                include_run_info=True,
            )
            st.session_state.messages.append(AIMessage(content=response["output"]))
            st.write(response["output"])
            memory.save_context({"input": user_question}, response)
            st.session_state["messages"] = memory.buffer
            run_id = response["__run"].run_id

            col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
            with col_text:
                st.text("Feedback:")
            with col1:
                st.button("👍", on_click=send_feedback, args=(run_id, 1))
            with col2:
                st.button("👎", on_click=send_feedback, args=(run_id, 0))

if uploaded_files:
    # Reset the generated questions when the uploaded files change
    if 'last_uploaded_files' not in st.session_state or st.session_state.last_uploaded_files != uploaded_files:
        st.session_state.last_uploaded_files = uploaded_files
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    # Load and process the uploaded PDF, TXT or DOCX files
    raw_text = load_docs(uploaded_files)
    st.success("Documents uploaded and processed.")

    gen_side_bar_qa(raw_text, model_name)
    memory, agent_executor = generate_memory(raw_text, model_name)
    render_chat(memory, agent_executor)
elif url_text and validators.url(url_text):
    # Reset the generated questions when the URL changes
    if 'url_files' not in st.session_state or st.session_state.url_files != url_text:
        st.session_state.url_files = url_text
        if 'eval_set' in st.session_state:
            del st.session_state['eval_set']

    # Load and process the web document
    loaded_docs = load_docs(url_text, url=True)
    st.success("Web Document uploaded and processed.")

    gen_side_bar_qa(loaded_docs, model_name)
    memory, agent_executor = generate_memory(loaded_docs, model_name)
    render_chat(memory, agent_executor)
st.markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-doc-gpt)") |