File size: 5,826 Bytes
aa8e01a 729d217 aa8e01a 729d217 aa8e01a |
|
import os
import gradio as gr
from dotenv import load_dotenv
from rag_system import load_retrieval_qa_chain, get_answer, update_embeddings
import json
import re
from PyPDF2 import PdfReader
from PIL import Image
import io
from pydantic_settings import BaseSettings
# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()

# Set OpenAI API key. Fail fast with a clear message: assigning None into
# os.environ would otherwise raise an opaque "str expected" TypeError.
openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment or in a .env file."
    )
os.environ["OPENAI_API_KEY"] = openai_api_key

# Ensure the static directory exists. exist_ok avoids the check-then-create
# race of testing os.path.exists first.
static_directory = "static"
os.makedirs(static_directory, exist_ok=True)
# PDF utility functions
def get_pdf_page_count(file_path):
    """Return how many pages the PDF stored at *file_path* contains.

    The file is opened in binary mode and parsed with ``PdfReader``; the
    handle is closed before returning.
    """
    with open(file_path, 'rb') as pdf_stream:
        reader = PdfReader(pdf_stream)
        page_total = len(reader.pages)
    return page_total
def render_pdf_page(file_path, page_num):
    """Render a single PDF page to a PIL image.

    Args:
        file_path: Path to the PDF file.
        page_num: 1-based page number to render.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode built from the page's pixmap.
    """
    import fitz  # PyMuPDF

    # Open the document as a context manager so its file handle is released
    # after each render; this function runs on every page turn.
    with fitz.open(file_path) as doc:
        page = doc.load_page(page_num - 1)  # PyMuPDF page indices start at 0
        pix = page.get_pixmap()
        # Image.frombytes copies the pixel buffer, so the image stays valid
        # after the document is closed.
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    return img
# Load PDF data
def load_pdf_data(documents_dir="./documents"):
    """Index the PDF files found directly inside *documents_dir*.

    Args:
        documents_dir: Directory to scan (not recursive). Defaults to
            ``"./documents"``.

    Returns:
        A dict mapping each PDF's file name to
        ``{'path': <file path>, 'num_pages': <page count>}``. Keys are
        inserted in sorted file-name order.
    """
    pdf_data = {}
    # os.listdir order is filesystem-dependent. Sort it so the dropdown order
    # and the first key (used as the initial document) are stable.
    for pdf_file in sorted(os.listdir(documents_dir)):
        file_path = os.path.join(documents_dir, pdf_file)
        # Match the extension case-insensitively, and skip directories that
        # merely end in ".pdf": they cannot be opened as files.
        if not pdf_file.lower().endswith('.pdf') or not os.path.isfile(file_path):
            continue
        pdf_data[pdf_file] = {
            'path': file_path,
            'num_pages': get_pdf_page_count(file_path),
        }
    return pdf_data
# Update embeddings with new documents.
# Runs once at import time, before the QA chain is loaded below.
# NOTE(review): presumably this indexes the files under ./documents into the
# vector store — confirm against rag_system.update_embeddings.
update_embeddings()
# Load vector store and PDF data.
# qa_chain is the module-level chain that chat_interface passes to get_answer;
# pdf_data maps PDF file names to their path and page count.
qa_chain = load_retrieval_qa_chain()
pdf_data = load_pdf_data()
def pdf_viewer_interface(pdf_state, page_number, action=None, page_input=None):
    """Work out the page to show after a navigation event, then render it.

    Args:
        pdf_state: Viewer state dict whose ``'selected_pdf'`` key names an
            entry in ``pdf_data``. Its ``'page_number'`` is updated in place.
        page_number: The page currently displayed (1-based).
        action: ``"prev"`` or ``"next"`` to step one page. Any other value
            falls through to *page_input*.
        page_input: Page number typed by the user. It is used only when
            *action* is neither ``"prev"`` nor ``"next"``, and is ignored if
            it cannot be converted with ``int()``.

    Returns:
        ``(image, page, page_text)``: the rendered page image, the new
        1-based page number, and that number as a string.
    """
    selected_pdf = pdf_state['selected_pdf']
    max_pages = pdf_data[selected_pdf]['num_pages']

    current_page = page_number
    if action == "prev":
        current_page -= 1
    elif action == "next":
        current_page += 1
    elif page_input is not None:
        try:
            current_page = int(page_input)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric, NaN or infinite input: stay on the current page.
            pass

    # Clamp on every path. The incoming page may be stale (e.g. set for a
    # different document), and render_pdf_page must never see a page outside
    # 1..max_pages.
    current_page = max(1, min(current_page, max_pages))

    pdf_state['page_number'] = current_page
    img = render_pdf_page(pdf_data[selected_pdf]['path'], current_page)
    return img, current_page, str(current_page)
def chat_interface(user_input, chat_history, pdf_state):
    """Ask the QA chain a question and record the exchange.

    Args:
        user_input: The question typed by the user.
        chat_history: List of stored exchanges. This function appends
            ``(user_input, answer)`` tuples to it in place.
        pdf_state: Viewer state; part of the Gradio wiring but unused here.

    Returns:
        ``(chat_history, sources)``: the updated history, plus the value under
        ``"sources"`` in ``get_answer``'s result.
    """
    # get_answer receives the history flattened into a single list: the
    # items of every stored exchange, concatenated in order.
    flat_history = []
    for exchange in chat_history:
        flat_history.extend(exchange)

    result = get_answer(qa_chain, user_input, flat_history)
    answer, cited_sources = result["answer"], result["sources"]

    chat_history.append((user_input, answer))
    return chat_history, cited_sources
def handle_source_click(evt: gr.SelectData, sources, pdf_state, page_number):
    """Open the PDF page referenced by a selected source entry.

    Args:
        evt: Gradio selection event. ``evt.index`` is either a single index
            or a list/tuple whose first element is used as the row index
            into *sources*.
        sources: List of source strings shaped like
            ``"<file name> (Page <n>)"``.
        pdf_state: Viewer state dict. On success its ``'selected_pdf'`` and
            ``'page_number'`` are updated in place.
        page_number: Current page. It is returned unchanged when nothing is
            opened.

    Returns:
        ``(image, pdf_state, page, page_text)`` on success. If the selection
        cannot be resolved (index out of range, malformed source string, or a
        file not in ``pdf_data``), returns ``(None, pdf_state, page_number, "")``.
    """
    unchanged = (None, pdf_state, page_number, "")

    index = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    if not 0 <= index < len(sources):
        return unchanged

    # rpartition splits on the last " (Page " marker, so file names that
    # contain the marker survive. When the marker is absent it returns an
    # empty separator rather than raising.
    file_name, separator, page_str = sources[index].rpartition(' (Page ')
    if not separator:
        return unchanged
    try:
        page = int(page_str.rstrip(')'))
    except ValueError:
        return unchanged

    if file_name not in pdf_data:
        return unchanged

    # Clamp to the document's real page range so render_pdf_page never
    # receives an out-of-range page from a stale or mis-numbered source.
    page = max(1, min(page, pdf_data[file_name]['num_pages']))

    pdf_state['selected_pdf'] = file_name
    pdf_state['page_number'] = page
    img = render_pdf_page(pdf_data[file_name]['path'], page)
    return img, pdf_state, page, str(page)
# Gradio UI: chat + sources table on the left, PDF viewer on the right.
# Components are created in layout order inside the Blocks context; the event
# listeners are registered at the end.
with gr.Blocks() as demo:
    # NOTE(review): raises IndexError when pdf_data is empty (no PDFs found).
    initial_pdf = list(pdf_data.keys())[0]
    # Viewer state: the selected PDF file name and its current 1-based page.
    pdf_state = gr.State({'selected_pdf': initial_pdf, 'page_number': 1})
    # Source strings from the latest answer, shaped "<file> (Page <n>)".
    sources = gr.State([])
    page_number = gr.State(1)
    with gr.Row():
        with gr.Column(scale=3):
            chat_history = gr.State([])
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Enter your question...")
            # Read-only table of [file name, page] rows derived from `sources`.
            source_list = gr.Dataframe(
                headers=["Source", "Page"],
                datatype=["str", "number"],
                row_count=4,
                col_count=2,
                interactive=False,
                label="Sources"
            )
        with gr.Column(scale=2):
            pdf_dropdown = gr.Dropdown(choices=list(pdf_data.keys()), label="Select PDF", value=initial_pdf)
            # NOTE(review): no event renders the initial page, so the viewer
            # stays empty until the user navigates or selects a source.
            pdf_viewer = gr.Image(label="PDF Viewer", height=600)
            pdf_page = gr.Number(label="Page Number", value=1)
            with gr.Row():
                prev_button = gr.Button("Previous Page")
                next_button = gr.Button("Next Page")
    # Submitting a question updates the chat and `sources`. The follow-up step
    # then splits each source string into a [file name, page] table row.
    # NOTE(review): the lambda raises if any source lacks " (Page <n>)".
    user_input.submit(chat_interface, [user_input, chat_history, pdf_state], [chatbot, sources]).then(
        lambda s: [[src.split(' (Page ')[0], int(src.split(' (Page ')[1].rstrip(')'))] for src in s],
        inputs=[sources],
        outputs=[source_list]
    )
    # Selecting a row in the Sources table opens that source's PDF page.
    source_list.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])
    # Changing the document first resets the state to page 1, then renders it;
    # gr.State(1) supplies the page_number argument.
    pdf_dropdown.change(
        lambda x: {'selected_pdf': x, 'page_number': 1},
        inputs=[pdf_dropdown],
        outputs=[pdf_state]
    ).then(
        pdf_viewer_interface,
        inputs=[pdf_state, gr.State(1)],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    # The constant gr.State("prev") / gr.State("next") inputs fill the
    # `action` argument of pdf_viewer_interface.
    prev_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("prev")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    next_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("next")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    # Typing a page number: action is None, so the typed value (4th input) is used.
    pdf_page.submit(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State(None), pdf_page],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    # NOTE(review): selecting a chat message passes that message's position
    # as evt.index, which handle_source_click then treats as an index into
    # `sources`. Confirm this mapping is intended.
    chatbot.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])
# Launch the Gradio server only when this file is run as a script, not when
# it is imported. (The stray " |" scrape residue after launch() was a syntax
# error and has been dropped.)
if __name__ == "__main__":
    demo.launch()