import os
import random

import fitz  # PyMuPDF
import gradio as gr
import markdownify
import pandas as pd
import requests
from gretel_client import Gretel
from gretel_client.config import GretelClientConfigurationError
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Directory where the example PDF, full-text markdown and chunk files are saved.
output_dir = 'processed_pdfs'
os.makedirs(output_dir, exist_ok=True)

# Example PDF offered by the "Continue with Example PDF" checkbox.
EXAMPLE_PDF_URL = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf"

# Seconds to wait for the example-PDF download before giving up.
DOWNLOAD_TIMEOUT_S = 60

# Prompt sent to Gretel Navigator; each seed row supplies `document` and `text`.
QA_INTRO_PROMPT = "From the source text below, create a dataset with the following columns:\n"
QA_COLUMN_DETAILS = (
    "* `topic`: Select a topic relevant for the given source text.\n"
    "* `user_profile`: The complexity level of the question and truth, categorized into beginner, intermediate, and expert.\n"
    " - Beginner users are about building foundational knowledge about the product and ask about basic features, benefits, and uses of the product.\n"
    " - Intermediate users have a deep understanding of the product and are focusing on practical applications, comparisons with other products, and intermediate-level features and benefits.\n"
    " - Expert users demonstrate in-depth technical knowledge, strategic application, and advanced troubleshooting. This level is for those who need to know the product inside and out, possibly for roles in sales, technical support, or product development.\n"
    "* `question`: Ask a set of unique questions related to the topic that a user might ask. "
    "Questions should be relatively complex and specific enough to be addressed in a short answer.\n"
    "* `answer`: Respond to the question with a clear, textbook quality answer that provides relevant details to fully address the question.\n"
    "* `context`: Copy the exact sentence(s) from the source text and surrounding details from where the answer can be derived.\n"
)
QA_PROMPT = QA_INTRO_PROMPT + QA_COLUMN_DETAILS

# Sampling parameters forwarded to `navigator.edit`.
GENERATE_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
}


def pdf_to_text(pdf_path):
    """Extract the text of every page of a PDF.

    Args:
        pdf_path: Path to a PDF file readable by PyMuPDF.

    Returns:
        The text of all pages concatenated in page order.
    """
    # The context manager closes the document (and its file handle) even if
    # extraction raises part-way through.
    with fitz.open(pdf_path) as pdf_document:
        return ''.join(page.get_text() for page in pdf_document)


def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
    """Split text into overlapping token-sized chunks and drop short ones.

    Args:
        text: Text to split.
        chunk_size: Target chunk size, counted with the tiktoken encoder.
        chunk_overlap: Tokens shared between consecutive chunks.
        min_chunk_chars: Chunks with fewer characters than this are discarded.

    Returns:
        List of chunk strings in document order.
    """
    # UI sliders may deliver floats; the splitter sizes are token counts.
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
    )
    chunks = text_splitter.split_text(text)
    return [chunk for chunk in chunks if len(chunk) >= min_chunk_chars]


def _is_chunk_file_of(filename, prefix):
    """Return True if *filename* is ``<prefix><digits>.md``."""
    if not (filename.startswith(prefix) and filename.endswith('.md')):
        return False
    return filename[len(prefix):-len('.md')].isdigit()


def save_chunks(file_id, chunks, output_dir):
    """Write chunks to ``<output_dir>/<file_id>_chunk_<n>.md`` with 1-based *n*.

    Chunk files left over from an earlier run for the same *file_id* are
    deleted first, so the directory never mixes chunks produced with
    different settings (e.g. a previous run that yielded more chunks).
    """
    prefix = f"{file_id}_chunk_"
    for filename in os.listdir(output_dir):
        if _is_chunk_file_of(filename, prefix):
            os.remove(os.path.join(output_dir, filename))
    for i, chunk in enumerate(chunks, start=1):
        chunk_path = os.path.join(output_dir, f"{prefix}{i}.md")
        with open(chunk_path, 'w', encoding='utf-8') as file:
            file.write(chunk)


def read_chunks_from_files(output_dir):
    """Load chunk files written by ``save_chunks`` back into memory.

    Returns:
        Dict mapping each file id to its chunks, ordered by chunk number
        (``os.listdir`` order is arbitrary, so an explicit sort is needed).
    """
    indexed_chunks = {}
    for filename in os.listdir(output_dir):
        stem, ext = os.path.splitext(filename)
        if ext != '.md' or '_chunk_' not in stem:
            continue
        # rpartition: the file id itself may contain "_chunk_"; the index is
        # always the last component.
        file_id, _, index = stem.rpartition('_chunk_')
        if not index.isdigit():
            continue
        with open(os.path.join(output_dir, filename), 'r', encoding='utf-8') as file:
            indexed_chunks.setdefault(file_id, []).append((int(index), file.read()))
    return {
        file_id: [chunk for _, chunk in sorted(items)]
        for file_id, items in indexed_chunks.items()
    }


def _download_example_pdf():
    """Download the example PDF into ``output_dir`` (once) and return its path."""
    pdf_path = os.path.join(output_dir, EXAMPLE_PDF_URL.split('/')[-1])
    if not os.path.exists(pdf_path):
        response = requests.get(EXAMPLE_PDF_URL, timeout=DOWNLOAD_TIMEOUT_S)
        # Fail loudly rather than caching an HTTP error body as the "PDF":
        # the file is only fetched when missing, so a bad copy would break
        # every later run.
        response.raise_for_status()
        with open(pdf_path, 'wb') as file:
            file.write(response.content)
    return pdf_path


def _uploaded_file_path(uploaded_file):
    """Return the on-disk path of an upload (an object with ``.name`` or a path string)."""
    return getattr(uploaded_file, 'name', uploaded_file)


def _process_pdf(pdf_path, chunk_size, chunk_overlap, min_chunk_chars):
    """Convert one PDF to markdown, chunk it, save both, and return ``(file_id, chunks)``."""
    text = pdf_to_text(pdf_path)
    # NOTE(review): markdownify parses its input as HTML, so extracted text
    # containing '<...>' sequences or markdown characters may be dropped or
    # escaped — confirm this conversion is intended for plain PDF text.
    markdown_text = markdownify.markdownify(text)
    file_id = os.path.splitext(os.path.basename(pdf_path))[0]
    markdown_path = os.path.join(output_dir, f"{file_id}.md")
    with open(markdown_path, 'w', encoding='utf-8') as file:
        file.write(markdown_text)
    chunks = split_text_into_chunks(
        markdown_text,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        min_chunk_chars=min_chunk_chars,
    )
    save_chunks(file_id, chunks, output_dir)
    return file_id, chunks


def _chunk_view(pdf_chunks_dict, selected_pdfs, current_chunk, direction):
    """Move the chunk cursor of the first selected PDF; return ``(chunk_text, index)``.

    The index is clamped to the valid range and is 0 when there are no chunks.
    """
    file_id = os.path.splitext(os.path.basename(selected_pdfs[0]))[0]
    chunks = (pdf_chunks_dict or {}).get(file_id, [])
    if not chunks:
        return "No chunks available.", 0
    # A missing cursor (None state) is treated as position 0.
    index = min(max((current_chunk or 0) + direction, 0), len(chunks) - 1)
    return chunks[index], index


def process_pdfs(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
    """Convert, chunk and save the selected PDFs and show a chunk of the first one.

    The example PDF takes precedence over uploads when *use_example* is set.

    Returns:
        ``(pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk)`` — the
        same order as the outputs wired to the "Process PDFs" button.
    """
    if use_example:
        selected_pdfs = [_download_example_pdf()]
    elif uploaded_files:
        selected_pdfs = [_uploaded_file_path(f) for f in uploaded_files]
    else:
        # Keep the cursor state an int so later navigation clicks still work.
        return None, None, "No PDFs processed", 0

    pdf_chunks_dict = {}
    for pdf_path in selected_pdfs:
        file_id, chunks = _process_pdf(pdf_path, chunk_size, chunk_overlap, min_chunk_chars)
        pdf_chunks_dict[file_id] = chunks

    chunk_text, current_chunk = _chunk_view(pdf_chunks_dict, selected_pdfs, current_chunk, direction)
    return pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk


def show_chunks(pdf_chunks_dict, selected_pdfs, current_chunk, direction):
    """Step *direction* chunks through the first processed PDF; return ``(text, index)``."""
    if not selected_pdfs:
        return "No PDF processed.", 0
    return _chunk_view(pdf_chunks_dict, selected_pdfs, current_chunk, direction)


def check_api_key(api_key):
    """Validate a Gretel API key.

    Returns:
        A ``gr.update`` that enables the generate button only for a valid key,
        and the status text "Valid" or "Invalid".
    """
    try:
        Gretel(api_key=api_key, validate=True, clear=True)
    except GretelClientConfigurationError:
        return gr.update(interactive=False), "Invalid"
    return gr.update(interactive=True), "Valid"


def generate_synthetic_records(api_key, pdf_chunks_dict, num_records):
    """Generate question/answer records from randomly sampled PDF chunks.

    Args:
        api_key: Gretel API key.
        pdf_chunks_dict: Mapping of document id to its chunk strings, as
            produced by ``process_pdfs``.
        num_records: Number of chunks to sample (with replacement) as seed rows.

    Returns:
        A ``gr.update`` showing the generated DataFrame without the seed
        ``text`` column.

    Raises:
        gr.Error: If no chunks are available or *num_records* is below 1.
    """
    all_chunks = [
        (doc, chunk)
        for doc, chunks in (pdf_chunks_dict or {}).items()
        for chunk in chunks
    ]
    if not all_chunks:
        raise gr.Error("No chunks available. Process PDFs before generating records.")
    # gr.Number may deliver a float (e.g. 10.0); range/sampling need an int.
    num_records = int(num_records or 0)
    if num_records < 1:
        raise gr.Error("Number of records must be at least 1.")

    gretel = Gretel(api_key=api_key, validate=True, clear=True)
    navigator = gretel.factories.initialize_inference_api("navigator")

    # Build the seed frame in one go instead of concatenating row by row.
    sampled = random.choices(all_chunks, k=num_records)
    df_in = pd.DataFrame(sampled, columns=['document', 'text'])

    df = navigator.edit(QA_PROMPT, seed_data=df_in, **GENERATE_PARAMS)
    # The raw seed text is not displayed; tolerate its absence in the result.
    df = df.drop(columns=['text'], errors='ignore')
    return gr.update(value=df, visible=True)


# CSS styling to center the logo and prevent right-click download.
# NOTE(review): this string (and the HTML below) is effectively empty in this
# copy of the file — the logo markup/CSS appears to have been stripped.
css = """ """

# HTML content to include the logo.
html_content = f""" {css}
"""

# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            gr.HTML(html_content)
            with gr.Tab("Upload PDF"):
                use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
                uploaded_files = gr.File(label="Upload your PDF files", file_count="multiple")
                chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
                chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
                min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)
                process_button = gr.Button("Process PDFs")

                pdf_chunks_dict = gr.State()
                selected_pdfs = gr.State()
                current_chunk = gr.State(value=0)
                chunk_text = gr.Textbox(label="Chunk Text", lines=10)

                def toggle_use_example(file_list):
                    """Untick the example option and lock it while uploads are present."""
                    return gr.update(value=False, interactive=not file_list)

                uploaded_files.change(
                    toggle_use_example,
                    inputs=[uploaded_files],
                    outputs=[use_example],
                )

                process_button.click(
                    process_pdfs,
                    inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
                    outputs=[pdf_chunks_dict, selected_pdfs, chunk_text, current_chunk],
                )

                with gr.Row():
                    prev_button = gr.Button("Previous Chunk", scale=1)
                    next_button = gr.Button("Next Chunk", scale=1)

                # The constant gr.State(-1 / 1) supplies the step direction.
                prev_button.click(
                    show_chunks,
                    inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(-1)],
                    outputs=[chunk_text, current_chunk],
                )
                next_button.click(
                    show_chunks,
                    inputs=[pdf_chunks_dict, selected_pdfs, current_chunk, gr.State(1)],
                    outputs=[chunk_text, current_chunk],
                )

        with gr.Column(scale=7):
            gr.Markdown("# Generate Question-Answer pairs")
            with gr.Row():
                api_key_input = gr.Textbox(label="API Key", type="password", placeholder="Enter your API key", scale=2)
                validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)

            # User-specific settings; precision=0 makes the component return an int.
            num_records = gr.Number(label="Number of Records", value=10, precision=0)
            generate_button = gr.Button("Generate Synthetic Records", interactive=False)

            # Validate API key on input change and update button interactivity.
            api_key_input.change(
                fn=check_api_key,
                inputs=[api_key_input],
                outputs=[generate_button, validate_status],
            )

            output_df = gr.Dataframe(
                headers=["document", "topic", "user_profile", "question", "answer", "context"],
                wrap=True,
                visible=True,
            )

            generate_button.click(
                fn=generate_synthetic_records,
                inputs=[api_key_input, pdf_chunks_dict, num_records],
                outputs=[output_df],
            )

demo.launch()