# Captured from a Hugging Face Spaces file view (Space status: Runtime error).
# File size: 9,408 bytes. The page's per-line commit gutter (commits 8999dd1,
# 17161a6, df07373, ace378a) and its line-number column are summarized in this note.
import os
import tempfile
import uuid
import streamlit as st
from dotenv import load_dotenv
from qdrant_client import models
from utils import setup_openai_embeddings, setup_qdrant_client, delete_collection, is_document_embedded
from embed import embed_documents_into_qdrant
from preprocess import split_documents, update_metadata, load_documents_OCR
from retrieve import retrieve_documents, retrieve_documents_from_collection
from summarize import summarize_documents
# Load environment variables (python-dotenv). The credentials this module later
# reads with os.getenv() are OPENAI_API_KEY, QDRANT_URL, QDRANT_API_KEY and
# UNSTRUCTURED_API.
load_dotenv()
def main():
    """Render the sidebar controls and dispatch to the selected page.

    The sidebar shows, top to bottom:
    - a multi-file PDF uploader;
    - when files are present, two action buttons: "Add Docs to Data Bank"
      and "Add Docs to Current Chat";
    - a radio selector whose choice decides which ``page_*`` function
      renders the main area.

    Session-state keys read by the pages are seeded with defaults only when
    missing, so values stored by the pages survive reruns.
    """
    sidebar = st.sidebar
    sidebar.title("PDF Management")
    uploaded_files = sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

    # Seed per-session slots once; later reruns keep whatever the pages stored.
    for state_key in ('uploaded_collection_name', 'data_bank_answer', 'uploaded_docs_answer'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    if uploaded_files:
        # Each action writes its own temp copies of the uploads before embedding.
        if sidebar.button("Add Docs to Data Bank"):
            embed_documents_to_data_bank(save_uploaded_files(uploaded_files))
        if sidebar.button("Add Docs to Current Chat"):
            add_docs_to_current_chat(save_uploaded_files(uploaded_files))

    page_renderers = {
        "Lex Document Summarization": page_summarization,
        "Chat with Data Bank": page_qna,
        "Chat with Uploaded Docs": page_chat_with_uploaded_docs,
    }
    sidebar.title("Page Navigation")
    selected_page = sidebar.radio("Select a page", tuple(page_renderers))

    # Cache of generated summaries, keyed by the uploaded file's original name.
    if 'summaries' not in st.session_state:
        st.session_state['summaries'] = {}

    # Every page function takes the current uploads as its only argument.
    if selected_page:
        page_renderers[selected_page](uploaded_files)
def save_uploaded_files(uploaded_files):
    """Write each uploaded file to its own ``.pdf`` temporary file.

    Each file is created with ``delete=False``, so it remains on disk after
    its handle is closed. Nothing in this function removes the files.

    Args:
        uploaded_files: Iterable of upload objects exposing ``getvalue()``
            (the raw bytes) and ``name`` (the original file name).

    Returns:
        A list of ``(temp_path, original_filename)`` tuples in input order.
    """
    def _persist(upload):
        # Copy one upload's bytes to disk and pair the path with its original name.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as handle:
            handle.write(upload.getvalue())
        return handle.name, upload.name

    return [_persist(upload) for upload in uploaded_files]
def page_summarization(uploaded_files):
    """Render the document-summarization page.

    One "Summarize <name>" button is shown per uploaded file. Clicking it runs
    ``load_documents_OCR`` on that file and ``summarize_documents`` on the
    result. The summary text is cached in ``st.session_state['summaries']``
    under the original file name, so it stays visible on later reruns.

    Args:
        uploaded_files: Files from the sidebar uploader. ``None`` or an empty
            list renders only the page title.
    """
    st.title("Lex Document Summarization")
    if not uploaded_files:
        return

    summaries = st.session_state['summaries']
    for uploaded_file in uploaded_files:
        original_name = uploaded_file.name
        summary_button = st.button(f"Summarize {original_name}", key=original_name)
        if not (summary_button or original_name in summaries):
            continue
        with st.container():
            st.write(f"Summary for {original_name}:")
            if summary_button:  # Only summarize if the button is pressed
                # Write the PDF to disk only when it is about to be OCR'd, and
                # remove it afterwards. Streamlit reruns the script on every
                # interaction, so saving every upload up front leaked one temp
                # file per upload per rerun.
                [(temp_path, _)] = save_uploaded_files([uploaded_file])
                try:
                    documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                    summaries[original_name] = summarize_documents(documents, os.getenv('OPENAI_API_KEY'))
                except Exception as e:
                    st.error(f"Failed to summarize {original_name}: {str(e)}")
                finally:
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass  # best-effort cleanup; a leftover temp file is harmless
            if original_name in summaries:
                st.text_area("", value=summaries[original_name], height=1000, key=f"summary_{original_name}")
            else:
                st.error(f"No summary found for {original_name}. Please click the summarize button.")
def page_qna(uploaded_files):
    """Render the "Chat with Data Bank" page.

    The question is answered by ``handle_query``. The latest answer is stored
    in ``st.session_state['data_bank_answer']`` and shown on every rerun.

    Args:
        uploaded_files: Unused. It is accepted so every page shares the
            ``page(uploaded_files)`` call signature that ``main`` dispatches with.
    """
    st.title("Chat with Data Bank")
    user_query = st.text_area("Enter your question here:", height=300)
    if st.button('Get Answer'):
        if user_query and user_query.strip():
            # Only store the answer; it is rendered once below. Writing it here
            # as well showed it twice on the run in which the button was clicked.
            st.session_state['data_bank_answer'] = handle_query(user_query)
        else:
            st.error("Please enter a question to get an answer.")

    # Display stored answer if it exists
    stored_answer = st.session_state.get('data_bank_answer')
    if stored_answer:
        st.write(stored_answer)
def page_chat_with_uploaded_docs(uploaded_files):
    """Render the "Chat with Uploaded Documents" page.

    Questions are answered by ``handle_uploaded_docs_query`` against the
    session collection named in ``st.session_state['uploaded_collection_name']``.
    The latest answer is stored in ``st.session_state['uploaded_docs_answer']``
    and shown on every rerun. While a session collection exists, a button
    offers to delete it.

    Args:
        uploaded_files: Unused. It is accepted so every page shares the
            ``page(uploaded_files)`` call signature that ``main`` dispatches with.
    """
    st.title("Chat with Uploaded Documents")
    user_query = st.text_area("Enter your question here:", height=300)
    if st.button('Get Answer'):
        if user_query and user_query.strip():
            # Only store the answer; it is rendered once below. Writing it here
            # as well showed it twice on the run in which the button was clicked.
            st.session_state['uploaded_docs_answer'] = handle_uploaded_docs_query(
                user_query, st.session_state.get('uploaded_collection_name')
            )
        else:
            st.error("Please enter a question to get an answer.")

    # Display stored answer if it exists
    stored_answer = st.session_state.get('uploaded_docs_answer')
    if stored_answer:
        st.write(stored_answer)

    collection_name = st.session_state.get('uploaded_collection_name')
    # The delete button is only rendered while a session collection exists.
    if collection_name and st.button('Delete Embedded Collection'):
        deleted = delete_collection(collection_name, os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
        # Only an explicit False counts as failure. The name is then kept so the
        # user can retry. None or any other result is treated as success, as before.
        if deleted is False:
            st.error(f"Failed to delete collection {collection_name}")
        else:
            st.session_state['uploaded_collection_name'] = None
            st.success(f"Deleted collection {collection_name}")
def embed_documents_to_data_bank(files_info):
    """Embed saved PDF files into the shared data-bank collection ``'Lex-v1'``.

    Each file that ``is_document_embedded`` does not already report goes
    through these steps:
    - OCR with ``load_documents_OCR``;
    - ``update_metadata`` with the file's original name;
    - ``split_documents``;
    - storage via ``embed_documents_into_qdrant``.

    Outcomes are shown in the Streamlit UI. An error while processing one
    file is reported, and the loop continues with the next file.

    Args:
        files_info: ``(temp_path, original_name)`` pairs, e.g. from
            ``save_uploaded_files``. Each temp file is deleted once processed,
            including when it is skipped or fails.
    """
    for temp_path, original_name in files_info:
        try:
            if is_document_embedded(original_name):
                st.info(f"{original_name} is already embedded.")
                continue
            try:
                documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                documents = update_metadata(documents, original_name)
                documents = split_documents(documents)
                if documents:
                    embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), 'Lex-v1')
                    st.success(f"Embedded {original_name} into Data Bank")
                else:
                    st.error(f"No documents found or extracted from {original_name}")
            except Exception as e:
                st.error(f"Failed to embed {original_name}: {str(e)}")
        finally:
            # The temp copy only exists to feed OCR. Without this, every call
            # left its temp files behind.
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless
def add_docs_to_current_chat(files_info, vector_size=1536):
    """Embed saved PDF files into this session's own Qdrant collection.

    On the first call of a session, this creates a collection named
    ``session-<uuid4>`` with cosine distance and ``vector_size``-dimensional
    vectors, then stores its name in
    ``st.session_state['uploaded_collection_name']``. Later calls reuse it.

    Each file is then OCR'd, passed through ``update_metadata``, split, and
    stored with ``embed_documents_into_qdrant``. Per-file outcomes are shown
    in the Streamlit UI.

    Args:
        files_info: ``(temp_path, original_name)`` pairs, e.g. from
            ``save_uploaded_files``. Each temp file is deleted once processed.
        vector_size: Vector dimension for a newly created collection. The
            default 1536 is presumably the output size of the OpenAI embedding
            model used by ``embed_documents_into_qdrant``. Confirm before changing it.
    """
    collection_name = st.session_state.get('uploaded_collection_name')
    if not collection_name:
        collection_name = f"session-{uuid.uuid4()}"
        client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
        )
        # Remember the name only once the collection exists. Storing it first
        # meant a failed create left a name that later calls never re-created.
        st.session_state['uploaded_collection_name'] = collection_name

    for temp_path, original_name in files_info:
        try:
            # NOTE(review): is_document_embedded receives only the file name, not
            # this session's collection. If it checks the shared data bank, a file
            # already there is skipped and never reaches the session collection.
            # Confirm.
            if is_document_embedded(original_name):
                st.info(f"{original_name} is already embedded.")
                continue
            try:
                documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                documents = update_metadata(documents, original_name)
                documents = split_documents(documents)
                if documents:
                    embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), collection_name=collection_name)
                    st.success(f"Embedded {original_name}")
                else:
                    st.error(f"No documents found or extracted from {original_name}")
            except Exception as e:
                st.error(f"Failed to embed {original_name}: {str(e)}")
        finally:
            # The temp copy only exists to feed OCR; remove it once processed.
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless
def handle_query(query):
    """Answer ``query`` from the shared data bank.

    Delegates to ``retrieve_documents`` with the OpenAI key and the Qdrant
    URL/key read from the environment.

    Returns:
        One of three strings:
        - the retriever's answer;
        - ``"No relevant answer found."`` when that answer is falsy;
        - ``"Error processing the query: <message>"`` when the lookup raises
          an ``Exception``. Such errors are reported as text, not propagated.
    """
    try:
        openai_key = os.getenv('OPENAI_API_KEY')
        qdrant_url = os.getenv('QDRANT_URL')
        qdrant_key = os.getenv('QDRANT_API_KEY')
        result = retrieve_documents(query, openai_key, qdrant_url, qdrant_key)
        return result if result else "No relevant answer found."
    except Exception as exc:
        return "Error processing the query: " + str(exc)
def handle_uploaded_docs_query(query, collection_name):
    """Answer ``query`` from this session's uploaded-documents collection.

    Args:
        query: The user's question.
        collection_name: Name of the session's Qdrant collection. ``None`` or
            an empty string means no documents have been added yet.

    Returns:
        One of four strings:
        - a hint message when there is no collection to search;
        - the retriever's answer;
        - ``"No relevant answer found."`` when that answer is falsy;
        - ``"Error processing the query: <message>"`` when the lookup raises.
    """
    if not collection_name:
        # Nothing has been embedded for this session yet. Tell the user instead
        # of passing a None collection name to the retriever.
        return "No documents have been added to the current chat yet. Use 'Add Docs to Current Chat' first."
    try:
        answer = retrieve_documents_from_collection(query, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), collection_name)
        return answer or "No relevant answer found."
    except Exception as e:
        return f"Error processing the query: {str(e)}"
def delete_collection(collection_name, qdrant_url, qdrant_api_key):
    """Delete a Qdrant collection, reporting (not raising) a failed delete.

    NOTE: This module-level definition comes after
    ``from utils import ... delete_collection``. It therefore shadows the
    imported helper, and it is what calls within this module resolve to.

    Args:
        collection_name: Collection to delete.
        qdrant_url: Qdrant server URL.
        qdrant_api_key: Qdrant API key.

    Returns:
        ``True`` if the client's delete call succeeded. ``False`` if it raised;
        in that case the error is printed and not propagated. Errors while
        creating the client still propagate.
    """
    client = setup_qdrant_client(qdrant_url, qdrant_api_key)
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception as e:
        print("Failed to delete collection:", e)
        return False
    return True
# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
# (end of captured file view)