taaha3244 committed on
Commit
17161a6
1 Parent(s): 4bb8487

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -33
app.py CHANGED
@@ -1,29 +1,41 @@
1
  import os
2
  import tempfile
 
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
- from main import (
6
- load_and_split_documents, summarize_documents, embed_documents_into_qdrant,
7
- retrieve_documents, is_document_embedded, load_documents, split_documents,
8
- update_metadata, load_documents_OCR
9
- )
 
 
 
 
 
10
 
11
  load_dotenv()
12
 
13
  def main():
14
  st.sidebar.title("PDF Management")
15
  uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
16
- model_name = st.sidebar.selectbox("Choose your model:", ["gpt-3.5-turbo", "gpt-4-turbo"]) # Model selection
17
- use_ocr = st.sidebar.checkbox("Use OCR for document processing")
18
-
19
- if st.sidebar.button('Add Uploaded Documents in Q&A'):
20
- if uploaded_files:
 
 
 
 
 
21
  files_info = save_uploaded_files(uploaded_files)
22
- embed_documents(files_info, model_name, use_ocr)
23
 
24
  pages = {
25
  "Lex Document Summarization": page_summarization,
26
- "Lex Q&A": page_qna
 
27
  }
28
 
29
  st.sidebar.title("Page Navigation")
@@ -35,10 +47,10 @@ def main():
35
 
36
  # Call the page function based on the user selection
37
  if page:
38
- pages[page](uploaded_files, model_name, use_ocr)
39
 
40
  def save_uploaded_files(uploaded_files):
41
- """Save uploaded files to temporary directory and return their file paths along with original filenames."""
42
  files_info = []
43
  for uploaded_file in uploaded_files:
44
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmpfile:
@@ -46,7 +58,7 @@ def save_uploaded_files(uploaded_files):
46
  files_info.append((tmpfile.name, uploaded_file.name))
47
  return files_info
48
 
49
- def page_summarization(uploaded_files, model_name, use_ocr):
50
  """Page for document summarization."""
51
  st.title("Lex Document Summarization")
52
  if uploaded_files:
@@ -56,42 +68,84 @@ def page_summarization(uploaded_files, model_name, use_ocr):
56
  if summary_button or (original_name in st.session_state['summaries']):
57
  with st.container():
58
  st.write(f"Summary for {original_name}:")
59
- if summary_button: # Only summarize if button is pressed
60
  try:
61
- if use_ocr:
62
- documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
63
- else:
64
- documents = load_and_split_documents(temp_path)
65
- summary = summarize_documents(model_name, documents, os.getenv('OPENAI_API_KEY'))
66
  st.session_state['summaries'][original_name] = summary # Store summary in session state
67
  except Exception as e:
68
  st.error(f"Failed to summarize {original_name}: {str(e)}")
69
  st.text_area("", value=st.session_state['summaries'][original_name], height=200, key=f"summary_{original_name}")
70
 
71
- def page_qna(uploaded_files, model_name, use_ocr):
72
  """Page for Q&A functionality."""
73
- st.title("Lex Question and Answer")
74
- user_query = st.text_area("Enter your question here:",height=300)
 
 
 
 
 
 
 
 
 
 
 
75
  if st.button('Get Answer'):
76
  if user_query:
77
- answer = handle_query(user_query, model_name)
78
  st.write(answer)
79
  else:
80
  st.error("Please enter a question to get an answer.")
 
 
 
 
 
 
 
81
 
82
- def embed_documents(files_info, model_name, use_ocr):
83
- """Function to embed documents."""
84
  for temp_path, original_name in files_info:
85
  if not is_document_embedded(original_name):
86
  try:
87
- if use_ocr:
88
- documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
89
- else:
90
- documents = load_documents(temp_path)
91
  documents = update_metadata(documents, original_name)
92
  documents = split_documents(documents)
93
  if documents:
94
  embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), 'Lex-v1')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  st.success(f"Embedded {original_name}")
96
  else:
97
  st.error(f"No documents found or extracted from {original_name}")
@@ -100,13 +154,29 @@ def embed_documents(files_info, model_name, use_ocr):
100
  else:
101
  st.info(f"{original_name} is already embedded.")
102
 
103
- def handle_query(query, model_name):
104
  """Retrieve answers based on the query."""
105
  try:
106
- answer = retrieve_documents(query, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), model_name)
107
  return answer or "No relevant answer found."
108
  except Exception as e:
109
  return f"Error processing the query: {str(e)}"
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  if __name__ == "__main__":
112
  main()
 
1
  import os
2
  import tempfile
3
+ import uuid
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
+ from qdrant_client import models
7
+ from langchain_community.vectorstores import Qdrant
8
+
9
+
10
+ from utils import setup_openai_embeddings,setup_qdrant_client,delete_collection,is_document_embedded
11
+ from embed import embed_documents_into_qdrant
12
+ from preprocess import split_documents,update_metadata,load_documents_OCR
13
+ from retrieve import retrieve_documents,retrieve_documents_from_collection
14
+ from summarize import summarize_documents
15
+
16
 
17
  load_dotenv()
18
 
19
  def main():
20
  st.sidebar.title("PDF Management")
21
  uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
22
+
23
+ if 'uploaded_collection_name' not in st.session_state:
24
+ st.session_state['uploaded_collection_name'] = None
25
+
26
+ if uploaded_files:
27
+ if st.sidebar.button("Add Docs to Data Bank"):
28
+ files_info = save_uploaded_files(uploaded_files)
29
+ embed_documents_to_data_bank(files_info)
30
+
31
+ if st.sidebar.button("Add Docs to Current Chat"):
32
  files_info = save_uploaded_files(uploaded_files)
33
+ add_docs_to_current_chat(files_info)
34
 
35
  pages = {
36
  "Lex Document Summarization": page_summarization,
37
+ "Chat with Data Bank": page_qna,
38
+ "Chat with Uploaded Docs": page_chat_with_uploaded_docs
39
  }
40
 
41
  st.sidebar.title("Page Navigation")
 
47
 
48
  # Call the page function based on the user selection
49
  if page:
50
+ pages[page](uploaded_files)
51
 
52
  def save_uploaded_files(uploaded_files):
53
+ """Save uploaded files to a temporary directory and return their file paths along with original filenames."""
54
  files_info = []
55
  for uploaded_file in uploaded_files:
56
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmpfile:
 
58
  files_info.append((tmpfile.name, uploaded_file.name))
59
  return files_info
60
 
61
+ def page_summarization(uploaded_files):
62
  """Page for document summarization."""
63
  st.title("Lex Document Summarization")
64
  if uploaded_files:
 
68
  if summary_button or (original_name in st.session_state['summaries']):
69
  with st.container():
70
  st.write(f"Summary for {original_name}:")
71
+ if summary_button: # Only summarize if the button is pressed
72
  try:
73
+ documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
74
+ summary = summarize_documents(documents, os.getenv('OPENAI_API_KEY'))
 
 
 
75
  st.session_state['summaries'][original_name] = summary # Store summary in session state
76
  except Exception as e:
77
  st.error(f"Failed to summarize {original_name}: {str(e)}")
78
  st.text_area("", value=st.session_state['summaries'][original_name], height=200, key=f"summary_{original_name}")
79
 
80
def page_qna(uploaded_files):
    """Render the "Chat with Data Bank" question-and-answer page.

    Shows a text area for the question and a 'Get Answer' button. When the
    button is pressed, a non-empty question is passed to ``handle_query`` and
    the result is written to the page; an empty question produces an error
    message instead.

    Args:
        uploaded_files: Not used by this page. It is accepted because the
            page dispatcher calls every page function with the uploaded files.
    """
    st.title("Chat with Data Bank")
    question = st.text_area("Enter your question here:", height=300)

    # Nothing further to render until the user presses the button.
    if not st.button('Get Answer'):
        return

    if not question:
        st.error("Please enter a question to get an answer.")
        return

    st.write(handle_query(question))
90
+
91
def page_chat_with_uploaded_docs(uploaded_files):
    """Render the page for chatting with documents added to the current chat.

    The page shows a question box with a 'Get Answer' button and, once a
    session collection exists, a 'Delete Embedded Collection' button.

    Args:
        uploaded_files: Not used by this page. It is accepted because the
            page dispatcher calls every page function with the uploaded files.
    """
    st.title("Chat with Uploaded Documents")
    user_query = st.text_area("Enter your question here:", height=300)

    # Name of the per-session collection, or None when no documents have been
    # added to the current chat yet.
    collection_name = st.session_state.get('uploaded_collection_name')

    if st.button('Get Answer'):
        if not user_query:
            st.error("Please enter a question to get an answer.")
        elif not collection_name:
            # Without a collection there is nothing to search; querying with
            # a missing collection name would only surface a retrieval error.
            st.warning(
                "No documents have been added to the current chat yet. "
                "Upload PDFs and use 'Add Docs to Current Chat' first."
            )
        else:
            answer = handle_uploaded_docs_query(user_query, collection_name)
            st.write(answer)

    if collection_name:
        if st.button('Delete Embedded Collection'):
            outcome = delete_collection(
                collection_name,
                os.getenv('QDRANT_URL'),
                os.getenv('QDRANT_API_KEY'),
            )
            # Only an explicit False counts as a failure, so a helper that
            # returns nothing keeps the previous always-success behaviour.
            if outcome is False:
                # Keep the collection name so the user can retry the deletion.
                st.error(f"Failed to delete collection {collection_name}")
            else:
                st.session_state['uploaded_collection_name'] = None
                st.success(f"Deleted collection {collection_name}")
108
 
109
def embed_documents_to_data_bank(files_info):
    """Embed each saved PDF into the shared 'Lex-v1' data bank collection.

    For every file that ``is_document_embedded`` does not already report as
    embedded, the document is loaded through the OCR loader, tagged with its
    original filename, split, and embedded. Progress and failures are
    reported on the page; an exception for one file does not stop the loop.

    Args:
        files_info: Iterable of ``(temp_path, original_name)`` pairs.
    """
    for temp_path, original_name in files_info:
        if is_document_embedded(original_name):
            st.info(f"{original_name} is already embedded.")
            continue

        try:
            loaded = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
            tagged = update_metadata(loaded, original_name)
            chunks = split_documents(tagged)
            if not chunks:
                st.error(f"No documents found or extracted from {original_name}")
            else:
                embed_documents_into_qdrant(
                    chunks,
                    os.getenv('OPENAI_API_KEY'),
                    os.getenv('QDRANT_URL'),
                    os.getenv('QDRANT_API_KEY'),
                    'Lex-v1',
                )
                st.success(f"Embedded {original_name} into Data Bank")
        except Exception as e:
            st.error(f"Failed to embed {original_name}: {str(e)}")
126
+
127
+ def add_docs_to_current_chat(files_info):
128
+ """Function to add documents to the current chat session."""
129
+ if not st.session_state['uploaded_collection_name']:
130
+ st.session_state['uploaded_collection_name'] = f"session-{uuid.uuid4()}"
131
+ client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
132
+ client.create_collection(
133
+ collection_name=st.session_state['uploaded_collection_name'],
134
+ vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE)
135
+ )
136
+ else:
137
+ client = setup_qdrant_client(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))
138
+
139
+ embeddings_model = setup_openai_embeddings(os.getenv('OPENAI_API_KEY'))
140
+
141
+ for temp_path, original_name in files_info:
142
+ if not is_document_embedded(original_name):
143
+ try:
144
+ documents = load_documents_OCR(temp_path, os.getenv('UNSTRUCTURED_API'))
145
+ documents = update_metadata(documents, original_name)
146
+ documents = split_documents(documents)
147
+ if documents:
148
+ embed_documents_into_qdrant(documents, os.getenv('OPENAI_API_KEY'), os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'), collection_name=st.session_state['uploaded_collection_name'])
149
  st.success(f"Embedded {original_name}")
150
  else:
151
  st.error(f"No documents found or extracted from {original_name}")
 
154
  else:
155
  st.info(f"{original_name} is already embedded.")
156
 
157
def handle_query(query):
    """Answer a question against the data bank.

    Calls ``retrieve_documents`` with the query and the OpenAI/Qdrant
    settings read from the environment.

    Args:
        query: The user's question.

    Returns:
        The retrieved answer, "No relevant answer found." when the answer is
        empty/falsy, or an error message string if any exception is raised.
    """
    try:
        result = retrieve_documents(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
        )
        if result:
            return result
        return "No relevant answer found."
    except Exception as e:
        return f"Error processing the query: {str(e)}"
164
 
165
def handle_uploaded_docs_query(query, collection_name):
    """Answer a question against a specific uploaded-documents collection.

    Calls ``retrieve_documents_from_collection`` with the query, the
    OpenAI/Qdrant settings read from the environment, and the collection name.

    Args:
        query: The user's question.
        collection_name: Name of the Qdrant collection to search.

    Returns:
        The retrieved answer, "No relevant answer found." when the answer is
        empty/falsy, or an error message string if any exception is raised.
    """
    try:
        result = retrieve_documents_from_collection(
            query,
            os.getenv('OPENAI_API_KEY'),
            os.getenv('QDRANT_URL'),
            os.getenv('QDRANT_API_KEY'),
            collection_name,
        )
        if result:
            return result
        return "No relevant answer found."
    except Exception as e:
        return f"Error processing the query: {str(e)}"
172
+
173
def delete_collection(collection_name, qdrant_url, qdrant_api_key):
    """Delete a Qdrant collection and report whether the call succeeded.

    NOTE(review): this module-level definition shadows the
    ``delete_collection`` imported from ``utils`` at the top of the file;
    confirm which of the two is meant to be used.

    Args:
        collection_name: Name of the collection to delete.
        qdrant_url: URL passed to ``setup_qdrant_client``.
        qdrant_api_key: API key passed to ``setup_qdrant_client``.

    Returns:
        True if the delete call completed without raising, False if it raised.
        The failure is still printed, as before, so existing log output is
        unchanged.
    """
    client = setup_qdrant_client(qdrant_url, qdrant_api_key)
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception as e:
        print("Failed to delete collection:", e)
        # Report the failure so callers do not have to assume success.
        return False
    return True
180
+
181
  if __name__ == "__main__":
182
  main()