# Lex / app.py
# Page header captured together with the file (not part of the program):
# uploaded by taaha3244 - commit 17161a6 "Update app.py" (verified) - 8.45 kB - no virus.
import os
import tempfile
import uuid
import streamlit as st
from dotenv import load_dotenv
from qdrant_client import models
from langchain_community.vectorstores import Qdrant
from utils import setup_openai_embeddings,setup_qdrant_client,delete_collection,is_document_embedded
from embed import embed_documents_into_qdrant
from preprocess import split_documents,update_metadata,load_documents_OCR
from retrieve import retrieve_documents,retrieve_documents_from_collection
from summarize import summarize_documents
# Run at import time so the os.getenv(...) lookups used throughout this module
# (OPENAI_API_KEY, QDRANT_URL, QDRANT_API_KEY, UNSTRUCTURED_API) can be
# satisfied from a local .env file -- presumably python-dotenv's default
# lookup; confirm where the .env file lives in deployment.
load_dotenv()
def _process_uploads(uploaded_files, action):
    """Write temp copies of the uploads, run ``action`` on them, then delete the copies."""
    files_info = save_uploaded_files(uploaded_files)
    try:
        action(files_info)
    finally:
        # save_uploaded_files creates the files with delete=False, so nothing
        # else will ever remove them; clean up even if the action raised.
        for temp_path, _ in files_info:
            try:
                os.remove(temp_path)
            except OSError:
                pass


def main():
    """Render the sidebar (upload + navigation) and dispatch to the selected page.

    The sidebar offers a PDF uploader and, once files are present, two
    actions: embedding them into the shared data bank or into the
    collection tied to the current chat session. Temporary copies written
    for either action are removed as soon as the action finishes.

    Session state keys initialised here:
        ``uploaded_collection_name``: name of the per-session collection,
            or ``None`` when no documents were added to the current chat.
        ``summaries``: dict mapping an uploaded filename to its summary.
    """
    st.sidebar.title("PDF Management")
    uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

    if 'uploaded_collection_name' not in st.session_state:
        st.session_state['uploaded_collection_name'] = None

    if uploaded_files:
        if st.sidebar.button("Add Docs to Data Bank"):
            _process_uploads(uploaded_files, embed_documents_to_data_bank)
        if st.sidebar.button("Add Docs to Current Chat"):
            _process_uploads(uploaded_files, add_docs_to_current_chat)

    pages = {
        "Lex Document Summarization": page_summarization,
        "Chat with Data Bank": page_qna,
        "Chat with Uploaded Docs": page_chat_with_uploaded_docs
    }
    st.sidebar.title("Page Navigation")
    page = st.sidebar.radio("Select a page", tuple(pages.keys()))

    # Initialize session state for summarization results if not already set
    if 'summaries' not in st.session_state:
        st.session_state['summaries'] = {}

    # Call the page function based on the user selection
    if page:
        pages[page](uploaded_files)
def save_uploaded_files(uploaded_files):
    """Save uploaded files to temporary ``.pdf`` files and return their paths.

    Args:
        uploaded_files: Iterable of upload objects exposing ``getvalue()``
            (the raw bytes) and ``name`` (the original filename). ``None``
            or an empty iterable yields an empty list.

    Returns:
        A list of ``(temp_path, original_name)`` tuples in upload order.
        The files are created with ``delete=False``; the caller owns them
        and is responsible for removing them when done.

    Raises:
        Exception: whatever reading or writing an upload raises. Temporary
            files already created by this call are removed first, because
            the caller never receives their paths in that case.
    """
    files_info = []
    created_paths = []
    try:
        for uploaded_file in uploaded_files or ():
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmpfile:
                # Track the path before writing so a failed write is cleaned up too.
                created_paths.append(tmpfile.name)
                tmpfile.write(uploaded_file.getvalue())
                files_info.append((tmpfile.name, uploaded_file.name))
    except Exception:
        for temp_path in created_paths:
            try:
                os.remove(temp_path)
            except OSError:
                pass
        raise
    return files_info
def page_summarization(uploaded_files):
    """Page for document summarization.

    Draws one "Summarize <name>" button per uploaded file. Pressing it runs
    OCR loading plus summarization and caches the result in
    ``st.session_state['summaries']`` under the original filename, so the
    summary stays visible on later reruns.

    Args:
        uploaded_files: Upload objects from the sidebar uploader; each must
            expose ``name`` and ``getvalue()``. Falsy means nothing to show.
    """
    st.title("Lex Document Summarization")
    if not uploaded_files:
        return

    summaries = st.session_state['summaries']
    for uploaded_file in uploaded_files:
        original_name = uploaded_file.name
        summary_button = st.button(f"Summarize {original_name}", key=original_name)
        if not (summary_button or original_name in summaries):
            continue

        with st.container():
            st.write(f"Summary for {original_name}:")
            if summary_button:  # Only summarize if the button is pressed
                # Write the temp copy only now (not on every rerun) and always
                # remove it afterwards; it is created with delete=False.
                temp_path = None
                try:
                    temp_path = save_uploaded_files([uploaded_file])[0][0]
                    documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                    summary = summarize_documents(documents, os.getenv('OPENAI_API_KEY'))
                    summaries[original_name] = summary  # Store summary in session state
                except Exception as e:
                    st.error(f"Failed to summarize {original_name}: {str(e)}")
                finally:
                    if temp_path is not None:
                        try:
                            os.remove(temp_path)
                        except OSError:
                            pass
            # A failed first attempt leaves no cached summary; skip the text
            # area in that case instead of raising KeyError.
            if original_name in summaries:
                st.text_area("", value=summaries[original_name], height=200, key=f"summary_{original_name}")
def page_qna(uploaded_files):
    """Render the "Chat with Data Bank" page.

    Shows a question box and a button; on click, the question is answered
    via ``handle_query`` and the result is written to the page. An empty
    question produces an error message instead.

    Args:
        uploaded_files: Unused here; accepted so every page function shares
            the same call signature.
    """
    st.title("Chat with Data Bank")
    question = st.text_area("Enter your question here:", height=300)

    if not st.button('Get Answer'):
        return
    if not question:
        st.error("Please enter a question to get an answer.")
        return
    st.write(handle_query(question))
def page_chat_with_uploaded_docs(uploaded_files):
    """Page for chatting with documents added to the current chat session.

    Answers questions against the collection named in
    ``st.session_state['uploaded_collection_name']`` and, while such a
    collection exists, offers a button to delete it.

    Args:
        uploaded_files: Unused here; accepted so every page function shares
            the same call signature.
    """
    st.title("Chat with Uploaded Documents")
    user_query = st.text_area("Enter your question here:", height=300)

    if st.button('Get Answer'):
        collection_name = st.session_state['uploaded_collection_name']
        if not user_query:
            st.error("Please enter a question to get an answer.")
        elif not collection_name:
            # Without a session collection there is nothing to search; say so
            # explicitly rather than sending None to the retrieval backend.
            st.warning('No documents have been added to this chat yet. '
                       'Upload PDFs and click "Add Docs to Current Chat" first.')
        else:
            answer = handle_uploaded_docs_query(user_query, collection_name)
            st.write(answer)

    if st.session_state['uploaded_collection_name']:
        if st.button('Delete Embedded Collection'):
            collection_name = st.session_state['uploaded_collection_name']
            delete_collection(collection_name, os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
            st.session_state['uploaded_collection_name'] = None
            st.success(f"Deleted collection {collection_name}")
def embed_documents_to_data_bank(files_info):
    """Embed each saved upload into the 'Lex-v1' data-bank collection.

    Files that ``is_document_embedded`` reports as already present are
    skipped with an info message. Every other file is OCR-loaded, tagged
    with its original name, split, and embedded; the outcome of each file
    is reported on the page and one failure does not stop the rest.

    Args:
        files_info: Iterable of ``(temp_path, original_name)`` tuples.
    """
    for pdf_path, original_name in files_info:
        if is_document_embedded(original_name):
            st.info(f"{original_name} is already embedded.")
            continue

        try:
            loaded = load_documents_OCR(pdf_path, os.getenv('UNSTRUCTURED_API'))
            tagged = update_metadata(loaded, original_name)
            chunks = split_documents(tagged)
            if not chunks:
                st.error(f"No documents found or extracted from {original_name}")
                continue
            embed_documents_into_qdrant(
                chunks,
                os.getenv('OPENAI_API_KEY'),
                os.getenv('QDRANT_URL'),
                os.getenv('QDRANT_API_KEY'),
                'Lex-v1',
            )
            st.success(f"Embedded {original_name} into Data Bank")
        except Exception as e:
            st.error(f"Failed to embed {original_name}: {str(e)}")
def add_docs_to_current_chat(files_info):
    """Embed saved uploads into the collection tied to the current chat session.

    On first use a new ``session-<uuid>`` collection is created and its name
    is stored in ``st.session_state['uploaded_collection_name']``. The name
    is stored only after creation succeeded, so a failed creation does not
    leave the session pointing at a collection that does not exist.

    Args:
        files_info: Iterable of ``(temp_path, original_name)`` tuples.
    """
    collection_name = st.session_state['uploaded_collection_name']
    if not collection_name:
        collection_name = f"session-{uuid.uuid4()}"
        try:
            client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
            client.create_collection(
                collection_name=collection_name,
                # NOTE(review): 1536 presumably matches the dimension of the
                # embeddings produced in embed_documents_into_qdrant -- confirm.
                vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE)
            )
        except Exception as e:
            st.error(f"Failed to create a collection for this chat: {str(e)}")
            return
        st.session_state['uploaded_collection_name'] = collection_name

    for temp_path, original_name in files_info:
        # NOTE(review): is_document_embedded receives only the filename; which
        # collection it checks is not visible here -- confirm it is the right
        # one for per-session uploads.
        if not is_document_embedded(original_name):
            try:
                documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                documents = update_metadata(documents, original_name)
                documents = split_documents(documents)
                if documents:
                    embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), collection_name=collection_name)
                    st.success(f"Embedded {original_name}")
                else:
                    st.error(f"No documents found or extracted from {original_name}")
            except Exception as e:
                st.error(f"Failed to embed {original_name}: {str(e)}")
        else:
            st.info(f"{original_name} is already embedded.")
def handle_query(query):
    """Answer ``query`` via ``retrieve_documents`` and return display text.

    Returns the retrieved answer, a fixed fallback sentence when the answer
    is empty, or an error sentence when retrieval raises.
    """
    try:
        result = retrieve_documents(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
        )
    except Exception as e:
        return f"Error processing the query: {str(e)}"
    if result:
        return result
    return "No relevant answer found."
def handle_uploaded_docs_query(query, collection_name):
    """Answer ``query`` against ``collection_name`` and return display text.

    Delegates to ``retrieve_documents_from_collection``. Returns the
    retrieved answer, a fixed fallback sentence when the answer is empty,
    or an error sentence when retrieval raises.
    """
    try:
        result = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            collection_name,
        )
    except Exception as e:
        return f"Error processing the query: {str(e)}"
    if result:
        return result
    return "No relevant answer found."
def delete_collection(collection_name, qdrant_url, qdrant_api_key):
    """Delete a Qdrant collection on a best-effort basis.

    Note: this definition replaces the ``delete_collection`` name imported
    from ``utils`` at the top of this file, so calls in this module use
    this version.

    Args:
        collection_name: Name of the collection to delete.
        qdrant_url: URL passed to ``setup_qdrant_client``.
        qdrant_api_key: API key passed to ``setup_qdrant_client``.

    Returns:
        ``True`` when the delete call completed without raising, ``False``
        when it raised (the error is printed, not propagated).
    """
    client = setup_qdrant_client(qdrant_url, qdrant_api_key)
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception as e:
        print("Failed to delete collection:", e)
        return False
    return True
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()