# NOTE: file-viewer residue, not part of the program — file size: 5,826 bytes;
# commits: aa8e01a, 729d217; original line-number gutter: 1-175.
import os
import gradio as gr
from dotenv import load_dotenv
from rag_system import load_retrieval_qa_chain, get_answer, update_embeddings
import json
import re
from PyPDF2 import PdfReader
from PIL import Image
import io
from pydantic_settings import BaseSettings
# Load environment variables (python-dotenv; typically from a local .env file).
load_dotenv()

# Set OpenAI API key. Fail fast with a clear message when it is missing:
# assigning None into os.environ would otherwise raise an opaque TypeError.
openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment or in a .env file."
    )
os.environ["OPENAI_API_KEY"] = openai_api_key

# Ensure the static directory exists. exist_ok=True replaces the separate
# existence check and so avoids the check-then-create race.
static_directory = "static"
os.makedirs(static_directory, exist_ok=True)
# PDF utility functions
def get_pdf_page_count(file_path):
    """Return the number of pages in the PDF stored at ``file_path``.

    The file is opened in binary mode and handed to ``PdfReader``; the page
    count is taken while the handle is still open, and the handle is closed
    before the function returns.
    """
    with open(file_path, 'rb') as pdf_stream:
        page_total = len(PdfReader(pdf_stream).pages)
    return page_total
def render_pdf_page(file_path, page_num):
    """Render one page of a PDF as a PIL image.

    Args:
        file_path: Path of the PDF file to open.
        page_num: 1-based page number; it is converted to a 0-based index
            before the page is loaded.

    Returns:
        An RGB ``PIL.Image.Image`` built from the page's pixmap.
    """
    import fitz  # PyMuPDF; kept as a function-local import, as in the original

    doc = fitz.open(file_path)
    try:
        page = doc.load_page(page_num - 1)  # page numbers start from 0
        pix = page.get_pixmap()
        # frombytes copies the pixel data, so the image remains usable after
        # the document is closed below.
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        # Release the document (and its file handle) even if rendering fails;
        # the original never closed it.
        doc.close()
    return img
# Load PDF data
def load_pdf_data(documents_dir="./documents"):
    """Index the PDF files found directly inside ``documents_dir``.

    Args:
        documents_dir: Directory to scan. Defaults to ``"./documents"``, the
            location that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        A dict mapping each PDF file name to
        ``{'path': <path to the file>, 'num_pages': <page count>}``.
        Names are visited in sorted order, so the mapping's iteration order
        (and whichever entry a caller treats as "first") is deterministic.

    Raises:
        FileNotFoundError: If ``documents_dir`` does not exist (raised by
            ``os.listdir``).
    """
    pdf_data = {}
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for pdf_file in sorted(os.listdir(documents_dir)):
        # Compare the extension case-insensitively so "REPORT.PDF" is kept.
        if not pdf_file.lower().endswith('.pdf'):
            continue
        file_path = f"{documents_dir}/{pdf_file}"
        pdf_data[pdf_file] = {
            'path': file_path,
            'num_pages': get_pdf_page_count(file_path),
        }
    return pdf_data
# Update embeddings with new documents. This runs at import time, before the
# chain below is loaded (NOTE(review): presumably so the chain sees the
# refreshed index — confirm in rag_system).
update_embeddings()
# Load the retrieval QA chain (vector store) and the per-PDF metadata
# (path and page count keyed by file name). Both are read as module-level
# globals by the UI handlers defined below.
qa_chain = load_retrieval_qa_chain()
pdf_data = load_pdf_data()
def pdf_viewer_interface(pdf_state, page_number, action=None, page_input=None):
    """Work out which page of the selected PDF to show and render it.

    Args:
        pdf_state: Dict holding ``'selected_pdf'`` (a key of the module-level
            ``pdf_data``). Its ``'page_number'`` entry is updated in place
            with the page that ends up being displayed.
        page_number: 1-based page currently shown.
        action: ``"prev"`` or ``"next"`` to step one page; any other value
            means no step.
        page_input: Explicit page requested by the user. Only consulted when
            ``action`` is neither ``"prev"`` nor ``"next"``. Input that cannot
            be converted to an integer is ignored and the current page kept.

    Returns:
        Tuple ``(image, page, page_text)``: the rendered page image, the
        resulting 1-based page number, and that number as a string.

    Raises:
        KeyError: If ``pdf_state['selected_pdf']`` is not a key of ``pdf_data``.
    """
    selected_pdf = pdf_state['selected_pdf']
    max_pages = pdf_data[selected_pdf]['num_pages']
    current_page = page_number

    if action == "prev":
        current_page -= 1
    elif action == "next":
        current_page += 1
    elif page_input is not None:
        try:
            current_page = int(page_input)
        except (TypeError, ValueError, OverflowError):
            # Deliberate best effort: unusable input (text, NaN, infinity,
            # wrong type) leaves the viewer on the page it already shows.
            pass

    # Clamp on every path, not only after a step or an explicit input, so a
    # stale page number can never point outside the selected document.
    current_page = max(1, min(current_page, max_pages))

    pdf_state['page_number'] = current_page
    pdf_path = pdf_data[selected_pdf]['path']
    img = render_pdf_page(pdf_path, current_page)
    return img, current_page, str(current_page)
def chat_interface(user_input, chat_history, pdf_state):
    """Answer ``user_input`` through the QA chain and record the exchange.

    Args:
        user_input: The question typed by the user.
        chat_history: List of ``(question, answer)`` pairs; the new pair is
            appended to this same list.
        pdf_state: Accepted for the event wiring; not used here.

    Returns:
        Tuple ``(chat_history, sources)`` where ``sources`` is the
        ``"sources"`` entry of the response returned by ``get_answer``.
    """
    # Flatten [(q1, a1), (q2, a2), ...] into [q1, a1, q2, a2, ...].
    flattened_turns = []
    for turn in chat_history:
        flattened_turns.extend(turn)

    result = get_answer(qa_chain, user_input, flattened_turns)
    answer_text, cited_sources = result["answer"], result["sources"]

    chat_history.append((user_input, answer_text))
    return chat_history, cited_sources
def handle_source_click(evt: gr.SelectData, sources, pdf_state, page_number):
    """Open the PDF page referenced by the source the user clicked.

    Args:
        evt: Selection event. ``evt.index`` may be a list/tuple (its first
            element is used as the position in ``sources``) or a bare value.
        sources: Source strings of the form ``"<file name> (Page <n>)"``.
        pdf_state: Dict updated in place with the selected file and page.
        page_number: Page currently shown; returned unchanged on failure.

    Returns:
        Tuple ``(image, pdf_state, page, page_text)``. When the click cannot
        be resolved (index out of range, malformed source string, unknown
        file) the result is ``(None, pdf_state, page_number, "")``.
    """
    fallback = (None, pdf_state, page_number, "")

    # Accept tuples as well as lists for the [row, col]-style index.
    index = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    # Reject non-integers and negative positions too; a negative index would
    # otherwise silently select a source counted from the end.
    if not isinstance(index, int) or not 0 <= index < len(sources):
        return fallback

    source = sources[index]
    # Split on the LAST " (Page " so a file name that itself contains that
    # text is still parsed; an empty separator means the marker is missing.
    file_name, marker, page_str = source.rpartition(' (Page ')
    if not marker:
        return fallback
    try:
        page = int(page_str.rstrip(')'))
    except ValueError:
        # Page part is not a number: treat like any other unresolvable click.
        return fallback

    if file_name not in pdf_data:
        return fallback

    # Keep the requested page inside the document, matching the clamping done
    # by pdf_viewer_interface.
    page = max(1, min(page, pdf_data[file_name]['num_pages']))

    pdf_state['selected_pdf'] = file_name
    pdf_state['page_number'] = page
    pdf_path = pdf_data[file_name]['path']
    img = render_pdf_page(pdf_path, page)
    return img, pdf_state, page, str(page)
# Gradio UI: chat + sources table on the left, PDF viewer on the right.
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been stripped — confirm it matches the intended layout.
with gr.Blocks() as demo:
    # Default to the first known PDF; None when no PDFs were found, so that
    # building the UI does not fail with IndexError on an empty folder.
    initial_pdf = next(iter(pdf_data), None)

    # Per-session state shared between the event handlers.
    pdf_state = gr.State({'selected_pdf': initial_pdf, 'page_number': 1})
    sources = gr.State([])
    page_number = gr.State(1)

    with gr.Row():
        with gr.Column(scale=3):
            chat_history = gr.State([])
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Enter your question...")
            source_list = gr.Dataframe(
                headers=["Source", "Page"],
                datatype=["str", "number"],
                row_count=4,
                col_count=2,
                interactive=False,
                label="Sources"
            )
        with gr.Column(scale=2):
            pdf_dropdown = gr.Dropdown(choices=list(pdf_data), label="Select PDF", value=initial_pdf)
            pdf_viewer = gr.Image(label="PDF Viewer", height=600)
            pdf_page = gr.Number(label="Page Number", value=1)
            with gr.Row():
                prev_button = gr.Button("Previous Page")
                next_button = gr.Button("Next Page")

    # Ask a question, then turn each "<file> (Page <n>)" source string into a
    # [file, page] row for the sources table.
    user_input.submit(chat_interface, [user_input, chat_history, pdf_state], [chatbot, sources]).then(
        lambda s: [[src.split(' (Page ')[0], int(src.split(' (Page ')[1].rstrip(')'))] for src in s],
        inputs=[sources],
        outputs=[source_list]
    )

    # Clicking a row in the sources table opens that page in the viewer.
    source_list.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])

    # Choosing another PDF resets the state to page 1, then renders that page.
    pdf_dropdown.change(
        lambda x: {'selected_pdf': x, 'page_number': 1},
        inputs=[pdf_dropdown],
        outputs=[pdf_state]
    ).then(
        pdf_viewer_interface,
        inputs=[pdf_state, gr.State(1)],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    # Page navigation; the constant gr.State values supply the `action` argument.
    prev_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("prev")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    next_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("next")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )
    # Jump to the page typed into the number box (action is None here).
    pdf_page.submit(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State(None), pdf_page],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    # Selections in the chatbot are routed through the same source handler.
    chatbot.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    demo.launch()