# Chat-with-PDFs / app.py
# Author: AbdalrhmanRi
# Commit ee72580 ("Update app.py", verified)
import streamlit as st
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader
import os
import tempfile
from deep_translator import GoogleTranslator
import asyncio
import uuid
import logging
from tenacity import retry, stop_after_attempt, wait_exponential
# Configure the root logger at import time: INFO and above, one timestamped line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
@st.cache_resource
def _load_embeddings():
    """Build the sentence-transformers embedding model shared by all sessions.

    ``st.cache_resource`` keeps one instance per server process (the same
    mechanism ``get_hf_client`` uses), so a new browser session reuses the
    loaded model instead of loading it again.
    """
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                 model_kwargs={'device': 'cpu'})  # Can use CUDA if you want on your device


def initialize_session_state():
    """Seed ``st.session_state`` with every key the app reads, if not already present.

    Each key is created only when missing, so Streamlit reruns keep existing values:
        generated: list of assistant answers, index-aligned with ``past``.
        past: list of user questions.
        memory: LangChain ``ConversationBufferMemory`` that records each exchange.
        vector_store: FAISS index over uploaded PDF chunks (``None`` until indexed).
        embeddings: embedding model used to build and query ``vector_store``.
        translation_states: maps ``"translate_<message_id>"`` to a show-translation flag.
        message_ids: per-answer UUID strings used to build stable widget keys.
        is_loading: boolean flag (not read elsewhere in the visible code).
    """
    if 'generated' not in st.session_state:
        st.session_state['generated'] = []
    if 'past' not in st.session_state:
        st.session_state['past'] = []
    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    if 'vector_store' not in st.session_state:
        st.session_state['vector_store'] = None
    if 'embeddings' not in st.session_state:
        # Shared, cached model instead of a fresh load per session.
        st.session_state['embeddings'] = _load_embeddings()
    if 'translation_states' not in st.session_state:
        st.session_state['translation_states'] = {}
    if 'message_ids' not in st.session_state:
        st.session_state['message_ids'] = []
    if 'is_loading' not in st.session_state:
        st.session_state['is_loading'] = False
async def process_pdf(file):
    """Load one uploaded PDF and split it into overlapping text chunks.

    Args:
        file: a readable binary file-like upload (a Streamlit ``UploadedFile``
            when called from ``main``).

    Returns:
        list of LangChain documents, 1000 characters each with 200 characters of
        overlap, as produced by ``RecursiveCharacterTextSplitter``.
    """
    # PyPDFLoader takes a filesystem path, so spill the upload to a temp file.
    # The file is closed (flushed) before the loader opens it.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
        temp_file.write(file.read())
        temp_file_path = temp_file.name
    try:
        loader = PyPDFLoader(temp_file_path)
        # Parsing is blocking; run it off the event loop so several PDFs overlap.
        documents = await asyncio.to_thread(loader.load)
    finally:
        # Always delete the temp copy, even when the PDF fails to parse.
        os.remove(temp_file_path)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return await asyncio.to_thread(text_splitter.split_documents, documents)
async def extract_text_from_pdfs(uploaded_files):
    """Chunk every uploaded PDF concurrently and return all chunks as one flat list.

    Chunks keep the order of ``uploaded_files``, then the order within each file.
    """
    per_file = await asyncio.gather(*(process_pdf(upload) for upload in uploaded_files))
    flattened = []
    for file_chunks in per_file:
        flattened.extend(file_chunks)
    return flattened
@st.cache_data(show_spinner=False)
def translate_text(text, dest_language='ar'):
    """Translate *text* into *dest_language* (Arabic by default) via Google Translate.

    The source language is auto-detected. Results are memoised by
    ``st.cache_data``, keyed on the arguments.
    """
    return GoogleTranslator(source='auto', target=dest_language).translate(text)
def update_vector_store(new_text_chunks):
    """Add document chunks to the session's FAISS index, creating it on first use.

    Args:
        new_text_chunks: list of LangChain documents to embed and index. An empty
            list is a no-op.
    """
    if not new_text_chunks:
        # e.g. a scanned PDF with no extractable text: FAISS cannot be built from
        # zero documents, so skip instead of crashing the page.
        return
    store = st.session_state['vector_store']
    if store is not None:
        store.add_documents(new_text_chunks)
    else:
        st.session_state['vector_store'] = FAISS.from_documents(new_text_chunks,
                                                                embedding=st.session_state['embeddings'])
@st.cache_resource
def get_hf_client():
    """Return a process-wide ``InferenceClient`` for Mistral-Nemo-Instruct-2407.

    The access token comes from the ``HF_TOKEN`` environment variable rather than
    being hard-coded in source control. When it is unset, ``None`` is passed and
    huggingface_hub falls back to its own token discovery (e.g. a saved
    ``huggingface-cli login``).
    """
    return InferenceClient(
        "mistralai/Mistral-Nemo-Instruct-2407",
        token=os.environ.get("HF_TOKEN"),
    )
def retrieve_relevant_chunks(query, max_tokens=1000, k=5):
    """Return the PDF text most similar to *query*, capped at a word budget.

    Args:
        query: user question to search for.
        max_tokens: budget counted as whitespace-separated words. Chunks are added
            in similarity order until the next one would exceed the budget.
        k: number of nearest chunks to fetch from the vector store.

    Returns:
        The selected chunk texts joined by newlines, or ``None`` when no PDF has
        been indexed or nothing fits in the budget.
    """
    store = st.session_state['vector_store']
    if store is None:
        return None
    relevant_chunks = []
    total_tokens = 0
    for doc, _score in store.similarity_search_with_score(query, k=k):
        chunk_tokens = len(doc.page_content.split())
        if total_tokens + chunk_tokens > max_tokens:
            break
        relevant_chunks.append(doc.page_content)
        total_tokens += chunk_tokens
    return "\n".join(relevant_chunks) if relevant_chunks else None
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def generate_response(query, conversation_context, relevant_chunk=None):
    """Ask the Mistral model to answer *query* and return the streamed text.

    Args:
        query: the user's question.
        conversation_context: "User:/Assistant:" transcript of earlier turns
            (may be empty).
        relevant_chunk: retrieved PDF text to ground the answer in, if any.

    Returns:
        The concatenated streamed answer, stripped of surrounding whitespace.

    Raises:
        Whatever the client raises. It is logged and re-raised so the ``retry``
        decorator can try again (up to 3 attempts with exponential backoff).
    """
    client = get_hf_client()
    if relevant_chunk:
        full_query = f"Based on the following information:\n{relevant_chunk}\n\nAnswer the question: {query}"
        if conversation_context:
            # Keep earlier turns so follow-up questions about the PDF still make sense.
            full_query = f"Conversation so far:\n{conversation_context}\n{full_query}"
    else:
        full_query = f"{conversation_context}\nUser: {query}"
    parts = []
    try:
        for message in client.chat_completion(
            messages=[{"role": "user", "content": full_query}],
            max_tokens=800,
            stream=True,
            temperature=0.3
        ):
            # Stream chunks may carry no choices or a None delta (e.g. the final
            # chunk); skip them instead of failing on `str += None`.
            if not message.choices:
                continue
            delta = message.choices[0].delta.content
            if delta:
                parts.append(delta)
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        raise
    return "".join(parts).strip()
def display_chat_interface():
    """Render the stored chat history, each answer with a "Translate to Arabic" toggle.

    Message IDs are assigned lazily and stored in ``message_ids`` so button keys
    stay stable across reruns. A failed translation is reported inline instead
    of aborting the rest of the page render.
    """
    session = st.session_state
    for i, answer in enumerate(session['generated']):
        with st.chat_message("user"):
            st.text(session["past"][i])
        with st.chat_message("assistant"):
            st.markdown(answer)
            message_ids = session['message_ids']
            if i < len(message_ids):
                message_id = message_ids[i]
            else:
                message_id = str(uuid.uuid4())
                message_ids.append(message_id)
            translate_key = f"translate_{message_id}"
            session['translation_states'].setdefault(translate_key, False)
            # The click is handled entirely by the on_click callback.
            st.button("Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
                      args=(translate_key,))
            if session['translation_states'][translate_key]:
                with st.spinner("Translating..."):
                    try:
                        translated_text = translate_text(answer)
                    except Exception:
                        # Network/length errors must not stop the rest of the history
                        # (and the chat input) from rendering on every rerun.
                        logging.exception("Translation failed")
                        st.error("Translation failed. Please try again later.")
                    else:
                        st.markdown(f"**Translated:** \n\n {translated_text}")
def toggle_translation(translate_key):
    """Flip the show-translation flag stored under *translate_key*.

    Registered as the ``on_click`` callback of the "Translate to Arabic" buttons.
    """
    states = st.session_state['translation_states']
    states[translate_key] = not states[translate_key]
def get_conversation_context(max_tokens=2000, past=None, generated=None):
    """Build a "User:/Assistant:" transcript of the most recent completed exchanges.

    Args:
        max_tokens: budget counted as whitespace-separated words. Exchanges are
            taken newest-first until the next one would exceed the budget.
        past: user questions; defaults to ``st.session_state['past']``.
        generated: assistant answers; defaults to ``st.session_state['generated']``.

    Returns:
        The selected exchanges in chronological order, each as
        ``"User: ...\\nAssistant: ...\\n"``; ``""`` when none fit.
    """
    if past is None:
        past = st.session_state['past']
    if generated is None:
        generated = st.session_state['generated']
    # Pair from the start: a trailing question with no answer yet is dropped
    # instead of shifting every question onto the previous turn's answer.
    exchanges = list(zip(past, generated))
    newest_first = []
    total_tokens = 0
    for user_text, assistant_text in reversed(exchanges):
        user_message = f"User: {user_text}\n"
        assistant_message = f"Assistant: {assistant_text}\n"
        message_tokens = len(user_message.split()) + len(assistant_message.split())
        if total_tokens + message_tokens > max_tokens:
            break
        newest_first.append(user_message + assistant_message)
        total_tokens += message_tokens
    return "".join(reversed(newest_first))
def validate_input(user_input):
    """Check that a chat message is non-blank and at most 500 characters long.

    Returns:
        ``(True, "")`` when the input is acceptable, otherwise ``(False, message)``
        with a user-facing explanation. Blank input is reported before length.
    """
    if user_input and user_input.strip():
        if len(user_input) <= 500:
            return True, ""
        return False, "Your input is too long. Please limit your question to 500 characters."
    return False, "Please enter a valid question or command."
def process_user_input(user_input):
    """Validate a chat message, answer it, record the exchange, then rerun the app.

    Invalid input is reported with ``st.error`` and nothing is recorded. Generation
    failures (after the retries in ``generate_response``) are replaced by an
    apology message. The question and answer are appended to ``past`` and
    ``generated`` together so the two lists stay index-aligned.
    """
    user_input = user_input.rstrip()
    is_valid, error_message = validate_input(user_input)
    if not is_valid:
        st.error(error_message)
        return
    with st.chat_message("user"):
        st.text(user_input)
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("⏳ Thinking...")
        relevant_chunk = retrieve_relevant_chunks(user_input)
        # Built before the new question is stored, so it holds only completed exchanges.
        conversation_context = get_conversation_context()
        try:
            output = generate_response(user_input, conversation_context, relevant_chunk)
        except Exception as e:
            logging.error(f"Failed to generate response after retries: {e}")
            output = "I apologize, but I'm having trouble processing your request at the moment. Please try again later."
        # markdown() on the st.empty placeholder replaces the "Thinking..." text.
        message_placeholder.markdown(output)
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(output)
    st.session_state['memory'].save_context({"input": user_input}, {"output": output})
    message_id = str(uuid.uuid4())
    st.session_state['message_ids'].append(message_id)
    st.session_state['translation_states'][f"translate_{message_id}"] = False
    # Rerun so display_chat_interface renders the new exchange, including its
    # translate button, together with the rest of the history.
    st.rerun()
def main():
    """Streamlit entry point: index uploaded PDFs, show the history, handle new input."""
    initialize_session_state()
    st.title("Chat with PDF Using Mistral AI")
    uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        # The uploader returns the same files on every rerun (every chat turn and
        # button click). Index only files not seen before; otherwise each rerun
        # would append duplicate chunks to the vector store.
        if 'processed_files' not in st.session_state:
            st.session_state['processed_files'] = set()
        processed = st.session_state['processed_files']
        new_files = [f for f in uploaded_files if (f.name, f.size) not in processed]
        if new_files:
            with st.spinner("Processing PDF files..."):
                # asyncio.run creates and closes its own loop, so no loop leaks per upload.
                new_text_chunks = asyncio.run(extract_text_from_pdfs(new_files))
                update_vector_store(new_text_chunks)
            processed.update((f.name, f.size) for f in new_files)
        st.success("PDF files uploaded and processed successfully.")
    display_chat_interface()
    user_input = st.chat_input("Ask about your PDF(s)")
    if user_input:
        process_user_input(user_input)
# Script entry point. Presumably launched with `streamlit run app.py`, which
# executes this file as __main__ — confirm for other launch methods.
if __name__ == "__main__":
    main()