File size: 9,408 Bytes
8999dd1
 
17161a6
8999dd1
df07373
17161a6
 
ace378a
17161a6
ace378a
 
17161a6
 
ace378a
8999dd1
 
 
 
 
17161a6
 
 
 
ace378a
 
 
 
 
 
17161a6
 
 
 
 
 
df07373
17161a6
8999dd1
df07373
 
ace378a
17161a6
df07373
8999dd1
df07373
 
8999dd1
df07373
 
 
 
 
 
17161a6
8999dd1
 
17161a6
8999dd1
 
 
 
 
 
 
17161a6
df07373
 
 
 
8999dd1
df07373
 
 
 
17161a6
df07373
17161a6
 
df07373
 
 
ace378a
 
 
 
df07373
17161a6
df07373
17161a6
 
 
 
 
ace378a
17161a6
 
 
ace378a
 
 
 
17161a6
 
 
 
df07373
 
17161a6
ace378a
df07373
 
 
ace378a
 
 
17161a6
 
 
 
 
 
 
df07373
17161a6
 
8999dd1
df07373
8999dd1
17161a6
df07373
 
 
 
17161a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df07373
 
 
8999dd1
 
 
 
 
17161a6
8999dd1
df07373
17161a6
df07373
 
 
8999dd1
17161a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8999dd1
ace378a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import tempfile
import uuid
import streamlit as st
from dotenv import load_dotenv
from qdrant_client import models

from utils import setup_openai_embeddings, setup_qdrant_client, delete_collection, is_document_embedded
from embed import embed_documents_into_qdrant
from preprocess import split_documents, update_metadata, load_documents_OCR
from retrieve import retrieve_documents, retrieve_documents_from_collection
from summarize import summarize_documents

# Load environment variables at import time (python-dotenv) so that the
# os.getenv() lookups in the functions below can see the API keys and URLs
# (OPENAI_API_KEY, QDRANT_URL, QDRANT_API_KEY, UNSTRUCTURED_API).
load_dotenv()

def _process_uploads(uploaded_files, handler):
    """Save the uploads to temp files, run handler on them, then delete the temp files."""
    files_info = save_uploaded_files(uploaded_files)
    try:
        handler(files_info)
    finally:
        # save_uploaded_files creates the files with delete=False, so nothing
        # else removes them; clean up here even if the handler raised.
        for temp_path, _ in files_info:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # Already gone or not removable; cleanup is best-effort.


def main():
    """Build the Streamlit UI: sidebar upload/embedding controls plus page navigation.

    Seeds the session-state keys the pages read, handles the two sidebar
    embedding buttons for the uploaded PDFs, and then renders the page chosen
    in the sidebar radio, passing it the current uploads.
    """
    st.sidebar.title("PDF Management")
    uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

    # Seed every session key read further down so pages can index
    # st.session_state directly without checking for missing keys.
    session_defaults = (
        ('uploaded_collection_name', None),
        ('data_bank_answer', None),
        ('uploaded_docs_answer', None),
        ('summaries', {}),
    )
    for key, default in session_defaults:
        if key not in st.session_state:
            st.session_state[key] = default

    if uploaded_files:
        if st.sidebar.button("Add Docs to Data Bank"):
            _process_uploads(uploaded_files, embed_documents_to_data_bank)

        if st.sidebar.button("Add Docs to Current Chat"):
            _process_uploads(uploaded_files, add_docs_to_current_chat)

    pages = {
        "Lex Document Summarization": page_summarization,
        "Chat with Data Bank": page_qna,
        "Chat with Uploaded Docs": page_chat_with_uploaded_docs
    }

    st.sidebar.title("Page Navigation")
    page = st.sidebar.radio("Select a page", tuple(pages.keys()))

    # Call the page function based on the user selection
    if page:
        pages[page](uploaded_files)

def save_uploaded_files(uploaded_files):
    """Write each upload to its own temporary ``.pdf`` file.

    Args:
        uploaded_files: Iterable of objects exposing ``getvalue()`` (the file
            bytes) and ``name`` (the original filename).

    Returns:
        A list of ``(temp_path, original_name)`` tuples, in input order. The
        temporary files are created with ``delete=False``, so they stay on
        disk until something removes them.
    """
    saved = []
    for upload in uploaded_files:
        handle = tempfile.NamedTemporaryFile(suffix='.pdf', delete=False)
        with handle:
            handle.write(upload.getvalue())
        # Only reached when the write succeeded, so failed uploads are never listed.
        saved.append((handle.name, upload.name))
    return saved

def page_summarization(uploaded_files):
    """Render the summarization page: one "Summarize" button per uploaded PDF.

    Summaries are kept in ``st.session_state['summaries']`` keyed by the
    original filename, so a summary produced on an earlier run is shown again
    without recomputing it. That key must already exist in session state
    (``main`` initializes it).

    Args:
        uploaded_files: The files returned by the sidebar uploader; may be
            empty or None, in which case only the title is rendered.
    """
    st.title("Lex Document Summarization")
    if not uploaded_files:
        return

    for uploaded_file in uploaded_files:
        original_name = uploaded_file.name
        summary_button = st.button(f"Summarize {original_name}", key=original_name)
        if not (summary_button or original_name in st.session_state['summaries']):
            continue

        with st.container():
            st.write(f"Summary for {original_name}:")
            if summary_button:  # Only summarize if the button is pressed
                # Write the temp copy only now. This function runs on every
                # Streamlit rerun, and saving all uploads up front left a new
                # undeleted temp file per upload on each rerun.
                temp_path = None
                try:
                    [(temp_path, _)] = save_uploaded_files([uploaded_file])
                    documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                    summary = summarize_documents(documents, os.getenv('OPENAI_API_KEY'))
                    st.session_state['summaries'][original_name] = summary  # Store summary in session state
                except Exception as e:
                    st.error(f"Failed to summarize {original_name}: {str(e)}")
                finally:
                    if temp_path is not None:
                        try:
                            os.remove(temp_path)
                        except OSError:
                            pass  # Cleanup is best-effort.
            if original_name in st.session_state['summaries']:
                st.text_area("", value=st.session_state['summaries'][original_name], height=1000, key=f"summary_{original_name}")
            else:
                st.error(f"No summary found for {original_name}. Please click the summarize button.")

def page_qna(uploaded_files):
    """Render the "Chat with Data Bank" page.

    The most recent answer is kept in ``st.session_state['data_bank_answer']``
    so it stays visible across reruns.

    Args:
        uploaded_files: Not used by this page; accepted so all page functions
            can be called with the same argument.
    """
    st.title("Chat with Data Bank")
    user_query = st.text_area("Enter your question here:", height=300)
    if st.button('Get Answer'):
        if user_query:
            # Only store the answer here. The block below is the single place
            # that renders it; writing it here as well displayed it twice.
            st.session_state['data_bank_answer'] = handle_query(user_query)
        else:
            st.error("Please enter a question to get an answer.")
    # Display stored answer if it exists
    if st.session_state['data_bank_answer']:
        st.write(st.session_state['data_bank_answer'])

def page_chat_with_uploaded_docs(uploaded_files):
    """Render the page for chatting with the documents added to the current session.

    Reads the session collection name from
    ``st.session_state['uploaded_collection_name']`` and keeps the latest
    answer in ``st.session_state['uploaded_docs_answer']``. When a session
    collection exists, a button is offered to delete it.

    Args:
        uploaded_files: Not used by this page; accepted so all page functions
            can be called with the same argument.
    """
    st.title("Chat with Uploaded Documents")
    user_query = st.text_area("Enter your question here:", height=300)
    collection_name = st.session_state['uploaded_collection_name']

    if st.button('Get Answer'):
        if not user_query:
            st.error("Please enter a question to get an answer.")
        elif not collection_name:
            # No session collection exists yet, so there is nothing to query.
            st.error('No documents in the current chat. Use "Add Docs to Current Chat" in the sidebar first.')
        else:
            # Only store the answer here. The block below is the single place
            # that renders it; writing it here as well displayed it twice.
            st.session_state['uploaded_docs_answer'] = handle_uploaded_docs_query(user_query, collection_name)
    # Display stored answer if it exists
    if st.session_state['uploaded_docs_answer']:
        st.write(st.session_state['uploaded_docs_answer'])

    if collection_name:
        if st.button('Delete Embedded Collection'):
            delete_collection(collection_name, os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
            st.session_state['uploaded_collection_name'] = None
            st.success(f"Deleted collection {collection_name}")

def embed_documents_to_data_bank(files_info):
    """Embed each saved PDF into the shared 'Lex-v1' collection.

    Files for which ``is_document_embedded`` reports True are skipped with an
    info message. Every outcome (success, nothing extracted, failure) is
    reported through Streamlit status messages; nothing is returned.

    Args:
        files_info: Iterable of ``(temp_path, original_name)`` pairs.
    """
    for pdf_path, display_name in files_info:
        if is_document_embedded(display_name):
            st.info(f"{display_name} is already embedded.")
            continue

        try:
            # Pipeline: OCR-load -> tag with the original filename -> split.
            loaded = load_documents_OCR(pdf_path, os.getenv('UNSTRUCTURED_API'))
            tagged = update_metadata(loaded, display_name)
            chunks = split_documents(tagged)
            if not chunks:
                st.error(f"No documents found or extracted from {display_name}")
                continue
            embed_documents_into_qdrant(
                chunks,
                os.getenv('OPENAI_API_KEY'),
                os.getenv('QDRANT_URL'),
                os.getenv('QDRANT_API_KEY'),
                'Lex-v1',
            )
            st.success(f"Embedded {display_name} into Data Bank")
        except Exception as err:
            st.error(f"Failed to embed {display_name}: {str(err)}")

def add_docs_to_current_chat(files_info):
    """Embed the saved PDFs into a per-session Qdrant collection.

    On first use a collection named ``session-<uuid4>`` is created (1536-dim
    vectors, cosine distance) and its name is stored in
    ``st.session_state['uploaded_collection_name']``; later calls reuse it.
    Per-file outcomes are reported through Streamlit status messages.

    Args:
        files_info: Iterable of ``(temp_path, original_name)`` pairs.

    Raises:
        Whatever ``setup_qdrant_client`` / ``create_collection`` raise when
        the session collection cannot be created.
    """
    if not st.session_state['uploaded_collection_name']:
        new_collection = f"session-{uuid.uuid4()}"
        client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
        client.create_collection(
            collection_name=new_collection,
            vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE)
        )
        # Record the name only after creation succeeded. Storing it first
        # left a dangling name in session state when creation failed, and
        # every later call then skipped creation.
        st.session_state['uploaded_collection_name'] = new_collection

    collection_name = st.session_state['uploaded_collection_name']

    for temp_path, original_name in files_info:
        # NOTE(review): is_document_embedded receives only the filename, so it
        # presumably checks a fixed collection rather than this session's one;
        # confirm against utils.
        if not is_document_embedded(original_name):
            try:
                documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
                documents = update_metadata(documents, original_name)
                documents = split_documents(documents)
                if documents:
                    embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), collection_name=collection_name)
                    st.success(f"Embedded {original_name}")
                else:
                    st.error(f"No documents found or extracted from {original_name}")
            except Exception as e:
                st.error(f"Failed to embed {original_name}: {str(e)}")
        else:
            st.info(f"{original_name} is already embedded.")

def handle_query(query):
    """Answer a question via ``retrieve_documents``.

    Returns:
        The retrieved answer; "No relevant answer found." when the result is
        falsy; or an "Error processing the query: ..." string when anything
        raises. Exceptions are never propagated to the caller.
    """
    try:
        result = retrieve_documents(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
        )
        if not result:
            return "No relevant answer found."
        return result
    except Exception as err:
        return f"Error processing the query: {err!s}"

def handle_uploaded_docs_query(query, collection_name):
    """Answer a question via ``retrieve_documents_from_collection``.

    Args:
        query: The user's question.
        collection_name: Name of the collection to search.

    Returns:
        The retrieved answer; "No relevant answer found." when the result is
        falsy; or an "Error processing the query: ..." string when anything
        raises. Exceptions are never propagated to the caller.
    """
    try:
        found = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            collection_name,
        )
        return found if found else "No relevant answer found."
    except Exception as err:
        return f"Error processing the query: {err!s}"

def delete_collection(collection_name, qdrant_url, qdrant_api_key):
    """Delete a Qdrant collection, reporting whether the deletion succeeded.

    Note: this module-level definition shadows the ``delete_collection``
    imported from ``utils`` at the top of the file.

    Args:
        collection_name: Name of the collection to delete.
        qdrant_url: URL passed to ``setup_qdrant_client``.
        qdrant_api_key: API key passed to ``setup_qdrant_client``.

    Returns:
        True when the delete call completed without raising, False when it
        raised (the error is printed, as before). Previously the function
        always returned None, so callers had no way to detect a failure.
    """
    client = setup_qdrant_client(qdrant_url, qdrant_api_key)
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception as e:
        print("Failed to delete collection:", e)
        return False
    return True

# Build the UI only when this file is run as the main module, not when it is
# imported from another module.
if __name__ == "__main__":
    main()