import os
import sys
import logging
import time

import yaml
import gradio as gr
# Resolve paths relative to this file and make the local packages (src/,
# utils/) importable regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from src.document_retrieval import DocumentRetrieval
from utils.parsing.sambaparse import parse_doc_universal
from utils.vectordb.vector_db import VectorDb
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials

CONFIG_PATH = os.path.join(current_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(current_dir, 'data/my-vector-db')
logging.basicConfig(level=logging.INFO)
logging.info("Gradio app is running")
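
# All chat state lives in a single module-level object. Note that Gradio
# shares module-level state across sessions, so concurrent users share this
# instance; per-user isolation would need gr.State instead.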
class ChatState:
    """Container for the conversation chain, chat history, and vector store."""

    def __init__(self):
        self.conversation = None
        self.chat_history = []
        self.show_sources = True
        self.sources_history = []
        self.vectorstore = None
        self.input_disabled = True
        self.document_retrieval = None

chat_state = ChatState()
chat_state.document_retrieval = DocumentRetrieval()
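
# Send the user's question through the retrieval chain and append the
# (question, answer) pair to the history. The commented-out lines preserve
# the original, currently disabled, source-citation feature.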
def handle_userinput(user_question):
    if user_question:
        try:
            start_time = time.time()
            response = chat_state.conversation.invoke({"question": user_question})
            logging.info(f"Response time: {time.time() - start_time:.2f}s")
            chat_state.chat_history.append((user_question, response["answer"]))
            # sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
            # sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
            # chat_state.sources_history.append(sources_text)
            return chat_state.chat_history, ""  # , chat_state.sources_history
        except Exception as e:
            # Surface the error inside the chat window instead of returning a
            # bare string, which the Chatbot component cannot render as history.
            chat_state.chat_history.append((user_question, f"An error occurred: {e}"))
            return chat_state.chat_history, ""
    return chat_state.chat_history, ""
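
# Parse the uploaded file, embed its chunks, and wire up the QA chain.
# parse_doc_universal returns a 3-tuple whose last element is the list of
# text chunks; the first two elements are unused here.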
def process_documents(files, save_location=None):
    try:
        _, _, text_chunks = parse_doc_universal(doc=files)
        embeddings = chat_state.document_retrieval.load_embedding_model()
        collection_name = default_collection if not prod_mode else None
        vectorstore = chat_state.document_retrieval.create_vector_store(
            text_chunks, embeddings, output_db=save_location, collection_name=collection_name
        )
        chat_state.vectorstore = vectorstore
        chat_state.document_retrieval.init_retriever(vectorstore)
        chat_state.conversation = chat_state.document_retrieval.get_qa_retrieval_chain()
        chat_state.input_disabled = False
        return "Complete! You can now ask questions."
    except Exception as e:
        return f"An error occurred while processing: {e}"
def reset_conversation():
    chat_state.chat_history = []
    # chat_state.sources_history = []
    return chat_state.chat_history, ""

def show_selection(model):
    # Currently not wired to the db_btn radio below.
    return f"You selected: {model}"
# Read the config file
with open(CONFIG_PATH, 'r') as yaml_file:
    config = yaml.safe_load(yaml_file)

prod_mode = config.get('prod_mode', False)
default_collection = 'ekr_default_collection'

# Load env variables
initialize_env_variables(prod_mode)

caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes."""
with gr.Blocks() as demo:
    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")
    gr.Markdown("Powered by Llama-3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
    # The key is collected here but not yet wired to save_credentials.
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    docs = gr.File(label="Add PDF file (single)", file_types=["pdf"], file_count="single")

    # Step 2: Process PDF file
    gr.Markdown("## 2️⃣ Process document and create vector store")
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value="ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)
    process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)

    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear = gr.Button("Clear chat")
    # show_sources = gr.Checkbox(label="Show sources", value=True)
    sources_output = gr.Textbox(label="Sources", visible=False)

    msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
    clear.click(reset_conversation, outputs=[chatbot, msg])
    # show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
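
# Launch the app when run directly (python app.py).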
if __name__ == "__main__":
    demo.launch()