Spaces:
Runtime error
Runtime error
changes to app.py and document_retrieval.py
Browse files- app.py +40 -40
- src/document_retrieval.py +11 -9
app.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import logging
|
4 |
import yaml
|
5 |
import gradio as gr
|
6 |
-
import time
|
7 |
|
8 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
9 |
print(current_dir)
|
@@ -16,61 +14,61 @@ from utils.vectordb.vector_db import VectorDb
|
|
16 |
CONFIG_PATH = os.path.join(current_dir,'config.yaml')
|
17 |
PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
self.show_sources = True
|
27 |
-
self.sources_history = []
|
28 |
-
self.vectorstore = None
|
29 |
-
self.input_disabled = True
|
30 |
-
self.document_retrieval = None
|
31 |
|
32 |
-
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
def handle_userinput(user_question):
|
37 |
if user_question:
|
38 |
try:
|
39 |
-
|
40 |
-
|
41 |
-
response_time = time.time() - response_time
|
42 |
-
chat_state.chat_history.append((user_question, response["answer"]))
|
43 |
|
44 |
#sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
|
45 |
#sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
|
46 |
#state.sources_history.append(sources_text)
|
47 |
|
48 |
-
return
|
49 |
except Exception as e:
|
50 |
return f"An error occurred: {str(e)}", "" #, state.sources_history
|
51 |
-
|
|
|
|
|
52 |
|
53 |
-
def process_documents(files, save_location=None):
|
54 |
try:
|
55 |
#for doc in files:
|
56 |
_, _, text_chunks = parse_doc_universal(doc=files)
|
57 |
print(text_chunks)
|
58 |
#text_chunks = chat_state.document_retrieval.parse_doc(files)
|
59 |
-
embeddings =
|
60 |
collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
|
61 |
-
vectorstore =
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
return "Complete! You can now ask questions."
|
67 |
except Exception as e:
|
68 |
-
return f"An error occurred while processing: {str(e)}"
|
69 |
|
70 |
def reset_conversation():
|
71 |
-
|
72 |
#chat_state.sources_history = []
|
73 |
-
return
|
74 |
|
75 |
def show_selection(model):
|
76 |
return f"You selected: {model}"
|
@@ -89,7 +87,8 @@ caution_text = """⚠️ Note: depending on the size of your document, this coul
|
|
89 |
"""
|
90 |
|
91 |
with gr.Blocks() as demo:
|
92 |
-
|
|
|
93 |
gr.Markdown("# Enterprise Knowledge Retriever",
|
94 |
elem_id="title")
|
95 |
|
@@ -108,8 +107,8 @@ with gr.Blocks() as demo:
|
|
108 |
process_btn = gr.Button("🔄 Process")
|
109 |
gr.Markdown(caution_text)
|
110 |
|
111 |
-
|
112 |
-
process_btn.click(process_documents, inputs=[docs], outputs=
|
113 |
#process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
|
114 |
#load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
|
115 |
|
@@ -117,13 +116,14 @@ with gr.Blocks() as demo:
|
|
117 |
gr.Markdown("## 3️⃣ Chat with your document")
|
118 |
chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
|
119 |
msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
|
120 |
-
|
121 |
#show_sources = gr.Checkbox(label="Show sources", value=True)
|
122 |
sources_output = gr.Textbox(label="Sources", visible=False)
|
123 |
|
|
|
124 |
#msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
|
125 |
-
msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
|
126 |
-
|
127 |
#show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
|
128 |
|
129 |
if __name__ == "__main__":
|
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
import yaml
|
4 |
import gradio as gr
|
|
|
5 |
|
6 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
7 |
print(current_dir)
|
|
|
14 |
CONFIG_PATH = os.path.join(current_dir,'config.yaml')
|
15 |
PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
|
16 |
|
17 |
+
#class ChatState:
|
18 |
+
# def __init__(self):
|
19 |
+
# self.conversation = None
|
20 |
+
# self.chat_history = []
|
21 |
+
# self.show_sources = True
|
22 |
+
# self.sources_history = []
|
23 |
+
# self.vectorstore = None
|
24 |
+
# self.input_disabled = True
|
25 |
+
# self.document_retrieval = None
|
26 |
|
27 |
+
chat_history = gr.State()
|
28 |
+
chat_history = []
|
29 |
+
vectorstore = gr.State()
|
30 |
+
document_retrieval = gr.State()
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
document_retrieval = DocumentRetrieval()
|
33 |
|
34 |
+
def handle_userinput(user_question, conversation):
|
|
|
|
|
35 |
if user_question:
|
36 |
try:
|
37 |
+
response = conversation.invoke({"question": user_question})
|
38 |
+
chat_history.append((user_question, response["answer"]))
|
|
|
|
|
39 |
|
40 |
#sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
|
41 |
#sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
|
42 |
#state.sources_history.append(sources_text)
|
43 |
|
44 |
+
return chat_history, "" #, state.sources_history
|
45 |
except Exception as e:
|
46 |
return f"An error occurred: {str(e)}", "" #, state.sources_history
|
47 |
+
else:
|
48 |
+
return "An error occurred", ""
|
49 |
+
#return chat_history, "" #, state.sources_history
|
50 |
|
51 |
+
def process_documents(files, conversation, save_location=None):
|
52 |
try:
|
53 |
#for doc in files:
|
54 |
_, _, text_chunks = parse_doc_universal(doc=files)
|
55 |
print(text_chunks)
|
56 |
#text_chunks = chat_state.document_retrieval.parse_doc(files)
|
57 |
+
embeddings = document_retrieval.load_embedding_model()
|
58 |
collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
|
59 |
+
vectorstore = document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
|
60 |
+
#vectorstore = vectorstore
|
61 |
+
document_retrieval.init_retriever(vectorstore)
|
62 |
+
conversation = document_retrieval.get_qa_retrieval_chain()
|
63 |
+
#input_disabled = False
|
64 |
+
return conversation, "Complete! You can now ask questions."
|
65 |
except Exception as e:
|
66 |
+
return conversation, f"An error occurred while processing: {str(e)}"
|
67 |
|
68 |
def reset_conversation():
|
69 |
+
chat_history = []
|
70 |
#chat_state.sources_history = []
|
71 |
+
return chat_history, ""
|
72 |
|
73 |
def show_selection(model):
|
74 |
return f"You selected: {model}"
|
|
|
87 |
"""
|
88 |
|
89 |
with gr.Blocks() as demo:
|
90 |
+
conversation = gr.State()
|
91 |
+
|
92 |
gr.Markdown("# Enterprise Knowledge Retriever",
|
93 |
elem_id="title")
|
94 |
|
|
|
107 |
process_btn = gr.Button("🔄 Process")
|
108 |
gr.Markdown(caution_text)
|
109 |
|
110 |
+
# Preprocessing events
|
111 |
+
process_btn.click(process_documents, inputs=[docs, conversation], outputs=[conversation, setup_output])
|
112 |
#process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
|
113 |
#load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
|
114 |
|
|
|
116 |
gr.Markdown("## 3️⃣ Chat with your document")
|
117 |
chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
|
118 |
msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
|
119 |
+
clear_btn = gr.Button("Clear chat")
|
120 |
#show_sources = gr.Checkbox(label="Show sources", value=True)
|
121 |
sources_output = gr.Textbox(label="Sources", visible=False)
|
122 |
|
123 |
+
# Chatbot events
|
124 |
#msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
|
125 |
+
msg.submit(handle_userinput, inputs=[msg, conversation], outputs=[chatbot, msg])
|
126 |
+
clear_btn.click(reset_conversation, outputs=[chatbot,msg])
|
127 |
#show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
|
128 |
|
129 |
if __name__ == "__main__":
|
src/document_retrieval.py
CHANGED
@@ -21,7 +21,7 @@ repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
|
|
21 |
sys.path.append(kit_dir)
|
22 |
sys.path.append(repo_dir)
|
23 |
|
24 |
-
import streamlit as st
|
25 |
|
26 |
from utils.model_wrappers.api_gateway import APIGateway
|
27 |
from utils.vectordb.vector_db import VectorDb
|
@@ -30,7 +30,7 @@ from utils.visual.env_utils import get_wandb_key
|
|
30 |
CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
|
31 |
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
|
32 |
|
33 |
-
load_dotenv(os.path.join(kit_dir, '.env'))
|
34 |
|
35 |
|
36 |
from utils.parsing.sambaparse import parse_doc_universal
|
@@ -153,13 +153,15 @@ class DocumentRetrieval:
|
|
153 |
return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
|
154 |
|
155 |
def set_llm(self):
|
156 |
-
if self.prod_mode:
|
157 |
-
|
158 |
-
else:
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
163 |
|
164 |
llm = APIGateway.load_llm(
|
165 |
type=self.api_info,
|
|
|
21 |
sys.path.append(kit_dir)
|
22 |
sys.path.append(repo_dir)
|
23 |
|
24 |
+
#import streamlit as st
|
25 |
|
26 |
from utils.model_wrappers.api_gateway import APIGateway
|
27 |
from utils.vectordb.vector_db import VectorDb
|
|
|
30 |
CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
|
31 |
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
|
32 |
|
33 |
+
#load_dotenv(os.path.join(kit_dir, '.env'))
|
34 |
|
35 |
|
36 |
from utils.parsing.sambaparse import parse_doc_universal
|
|
|
153 |
return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
|
154 |
|
155 |
def set_llm(self):
|
156 |
+
#if self.prod_mode:
|
157 |
+
# sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
|
158 |
+
#else:
|
159 |
+
# if 'SAMBANOVA_API_KEY' in st.session_state:
|
160 |
+
# sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
|
161 |
+
# else:
|
162 |
+
# sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
|
163 |
+
|
164 |
+
sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
|
165 |
|
166 |
llm = APIGateway.load_llm(
|
167 |
type=self.api_info,
|