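# Streamlit "Chat with PDF" app: upload PDFs, index them in FAISS with
# sentence-transformer embeddings, answer questions with Mistral-Nemo via the
# Hugging Face Inference API, and optionally translate each answer to Arabic.
# Run with: streamlit run app.py  (the filename is assumed; use your own path).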
import streamlit as st
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader
import os
import tempfile
from deep_translator import GoogleTranslator
import asyncio
import uuid
import logging
from tenacity import retry, stop_after_attempt, wait_exponential

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
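
# Create every session-scoped container up front so later code can assume the keys exist.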
def initialize_session_state():
    if 'generated' not in st.session_state:
        st.session_state['generated'] = []
    if 'past' not in st.session_state:
        st.session_state['past'] = []
    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    if 'vector_store' not in st.session_state:
        st.session_state['vector_store'] = None
    if 'embeddings' not in st.session_state:
        st.session_state['embeddings'] = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                                               model_kwargs={'device': 'cpu'})  # set to 'cuda' to embed on a GPU
    if 'translation_states' not in st.session_state:
        st.session_state['translation_states'] = {}
    if 'message_ids' not in st.session_state:
        st.session_state['message_ids'] = []
    if 'is_loading' not in st.session_state:
        st.session_state['is_loading'] = False
    if 'processed_files' not in st.session_state:
        st.session_state['processed_files'] = set()
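
# Write the uploaded PDF to a temporary file (PyPDFLoader takes a path, not a
# file object), parse it off the event loop, then split it into overlapping chunks.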
async def process_pdf(file):
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(file.read())
        temp_file_path = temp_file.name
    loader = PyPDFLoader(temp_file_path)
    documents = await asyncio.to_thread(loader.load)
    os.remove(temp_file_path)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    text_chunks = await asyncio.to_thread(text_splitter.split_documents, documents)
    return text_chunks
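
# Parse all uploaded PDFs concurrently and flatten the per-file chunk lists.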
async def extract_text_from_pdfs(uploaded_files):
    tasks = [process_pdf(file) for file in uploaded_files]
    results = await asyncio.gather(*tasks)
    return [chunk for result in results for chunk in result]
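
# deep_translator auto-detects the source language; the default target is Arabic ('ar').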
def translate_text(text, dest_language='ar'):
    translator = GoogleTranslator(source='auto', target=dest_language)
    return translator.translate(text)
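
# Build the FAISS index on first upload, then append to it incrementally.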
def update_vector_store(new_text_chunks):
    if st.session_state['vector_store']:
        st.session_state['vector_store'].add_documents(new_text_chunks)
    else:
        st.session_state['vector_store'] = FAISS.from_documents(new_text_chunks,
                                                                embedding=st.session_state['embeddings'])
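
# The API token is read from the HF_TOKEN environment variable; the variable name
# is this script's convention, not something the API requires — any valid
# Hugging Face token works.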
def get_hf_client():
    return InferenceClient(
        "mistralai/Mistral-Nemo-Instruct-2407",
        token=os.environ.get("HF_TOKEN")
    )
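
# Greedy top-k retrieval under a rough token budget. Tokens are approximated by
# whitespace-split word counts, which is a deliberate simplification.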
def retrieve_relevant_chunks(query, max_tokens=1000):
    if st.session_state['vector_store']:
        search_results = st.session_state['vector_store'].similarity_search_with_score(query, k=5)
        relevant_chunks = []
        total_tokens = 0
        for doc, score in search_results:
            chunk_tokens = len(doc.page_content.split())
            if total_tokens + chunk_tokens > max_tokens:
                break
            relevant_chunks.append(doc.page_content)
            total_tokens += chunk_tokens
        return "\n".join(relevant_chunks) if relevant_chunks else None
    return None
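
# Stream the completion and accumulate it into one string. Transient API failures
# are retried with exponential backoff; the exact stop/wait values below are
# illustrative choices, not requirements of tenacity or the API.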
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def generate_response(query, conversation_context, relevant_chunk=None):
    client = get_hf_client()
    if relevant_chunk:
        full_query = f"Based on the following information:\n{relevant_chunk}\n\nAnswer the question: {query}"
    else:
        full_query = f"{conversation_context}\nUser: {query}"
    response = ""
    try:
        for message in client.chat_completion(
            messages=[{"role": "user", "content": full_query}],
            max_tokens=800,
            stream=True,
            temperature=0.3
        ):
            response += message.choices[0].delta.content or ""
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        raise
    return response.strip()
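
# Replay the chat history. Each assistant message gets a stable UUID so its
# "Translate to Arabic" toggle survives Streamlit reruns.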
def display_chat_interface():
    for i in range(len(st.session_state['generated'])):
        with st.chat_message("user"):
            st.text(st.session_state["past"][i])
        with st.chat_message("assistant"):
            st.markdown(st.session_state['generated'][i])
            if i >= len(st.session_state['message_ids']):
                message_id = str(uuid.uuid4())
                st.session_state['message_ids'].append(message_id)
            else:
                message_id = st.session_state['message_ids'][i]
            translate_key = f"translate_{message_id}"
            if translate_key not in st.session_state['translation_states']:
                st.session_state['translation_states'][translate_key] = False
            st.button("Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
                      args=(translate_key,))
            if st.session_state['translation_states'][translate_key]:
                with st.spinner("Translating..."):
                    translated_text = translate_text(st.session_state['generated'][i])
                    st.markdown(f"**Translated:** \n\n {translated_text}")
def toggle_translation(translate_key):
    st.session_state['translation_states'][translate_key] = not st.session_state['translation_states'][translate_key]
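
# Walk the history newest-first until the token budget is hit, then emit it
# oldest-first so the prompt reads chronologically.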
def get_conversation_context(max_tokens=2000):
    context = []
    total_tokens = 0
    for past, generated in zip(reversed(st.session_state['past']), reversed(st.session_state['generated'])):
        user_message = f"User: {past}\n"
        assistant_message = f"Assistant: {generated}\n"
        message_tokens = len(user_message.split()) + len(assistant_message.split())
        if total_tokens + message_tokens > max_tokens:
            break
        context.insert(0, user_message)
        context.insert(1, assistant_message)
        total_tokens += message_tokens
    return "".join(context)
def validate_input(user_input):
    if not user_input or not user_input.strip():
        return False, "Please enter a valid question or command."
    if len(user_input) > 500:
        return False, "Your input is too long. Please limit your question to 500 characters."
    return True, ""
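
# Handle one chat turn: echo the question, retrieve context, generate the answer,
# save it to memory, then rerun so display_chat_interface redraws the full history
# (including the translate button for the new message).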
def process_user_input(user_input):
    user_input = user_input.rstrip()
    is_valid, error_message = validate_input(user_input)
    if not is_valid:
        st.error(error_message)
        return
    st.session_state['past'].append(user_input)
    with st.chat_message("user"):
        st.text(user_input)
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("⏳ Thinking...")
        relevant_chunk = retrieve_relevant_chunks(user_input)
        conversation_context = get_conversation_context()
        try:
            output = generate_response(user_input, conversation_context, relevant_chunk)
        except Exception as e:
            logging.error(f"Failed to generate response after retries: {e}")
            output = "I apologize, but I'm having trouble processing your request at the moment. Please try again later."
        message_placeholder.empty()
        message_placeholder.markdown(output)
    st.session_state['generated'].append(output)
    st.session_state['memory'].save_context({"input": user_input}, {"output": output})
    message_id = str(uuid.uuid4())
    st.session_state['message_ids'].append(message_id)
    translate_key = f"translate_{message_id}"
    st.session_state['translation_states'][translate_key] = False
    st.rerun()
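
# Streamlit re-executes this script on every interaction, so files already embedded
# are tracked by name to avoid re-indexing them on each rerun.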
def main():
    initialize_session_state()
    st.title("Chat with PDF Using Mistral AI")
    uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        new_files = [f for f in uploaded_files if f.name not in st.session_state['processed_files']]
        if new_files:
            with st.spinner("Processing PDF files..."):
                new_text_chunks = asyncio.run(extract_text_from_pdfs(new_files))
                update_vector_store(new_text_chunks)
                st.session_state['processed_files'].update(f.name for f in new_files)
            st.success("PDF files uploaded and processed successfully.")
    display_chat_interface()
    user_input = st.chat_input("Ask about your PDF(s)")
    if user_input:
        process_user_input(user_input)

if __name__ == "__main__":
    main()