"""Extract captioned figures/tables from a PDF and have an OpenAI assistant write a thread about it.

All artifacts (the PDF copy, cropped figures, JSON results, the raw and processed
thread, and a Markdown rendering) are written under ``<output>/<pdf name>/``.
"""

import argparse
import html
import json
import logging
import math
import os
import re
import shutil
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse

import layoutparser as lp
import openai
import pytesseract
import requests
from dotenv import load_dotenv
from pdf2image import convert_from_bytes
from pydantic import BaseModel, ConfigDict

from create_assistant import create_assistant

load_dotenv()

logging.basicConfig(handlers=[logging.StreamHandler()], level=logging.INFO)
logger = logging.getLogger(__name__)

# Multiplier applied to the distance to text *below* a figure/table, so that a
# caption under the figure wins over equally distant text above it.
_BOTTOM_CAPTION_BIAS = 0.75
# Seconds to wait for a PDF download before giving up.
_DOWNLOAD_TIMEOUT_SECONDS = 60
# Fenced code blocks the assistant may wrap its JSON answer in.
_JSON_FENCE_RE = re.compile(r"```json\r?\n(.*?)\r?\n```", re.DOTALL)
_PLAIN_FENCE_RE = re.compile(r"```\r?\n(.*?)\r?\n```", re.DOTALL)
# Source annotations such as "【13†source】" or "【4:0†source】".
_SOURCE_ANNOTATION_RE = re.compile(r"【[^】]*†[^】]*】")
# Run states in which the assistant has not finished yet.
_PENDING_RUN_STATES = ("queued", "in_progress", "cancelling")


class Block(BaseModel):
    """A detected layout element together with the page it was found on."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Layout element from the layoutparser model; this file reads its .type,
    # .area and the .block rectangle (x_1, y_1, x_2, y_2).
    block: lp.elements.base.BaseLayoutElement
    # 0-based index into the list of rendered page images.
    page_index: int


class CaptionedBlock(Block):
    """A figure/table block paired with the text block chosen as its caption."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Text layout element selected by caption_blocks as this block's caption.
    caption: lp.elements.base.BaseLayoutElement


def _box_area(rect) -> float:
    """Return the area of an axis-aligned box, or 0 for a degenerate box."""
    return max(rect.x_2 - rect.x_1, 0) * max(rect.y_2 - rect.y_1, 0)


def _intersection_area(a, b) -> float:
    """Return the area shared by two axis-aligned boxes, or 0 if they are disjoint."""
    # Clamp each extent at zero: without the clamp, boxes that are disjoint on
    # both axes would multiply two negative extents into a positive "overlap".
    width = min(a.x_2, b.x_2) - max(a.x_1, b.x_1)
    height = min(a.y_2, b.y_2) - max(a.y_1, b.y_1)
    return max(width, 0) * max(height, 0)


def _add_or_merge_block(blocks: list[Block], element, page_index: int, overlap_threshold: float = 0.5) -> None:
    """Add ``element`` to ``blocks`` in place, merging it with a near-duplicate if one exists.

    Two boxes on the same page are near-duplicates when their intersection
    covers more than ``overlap_threshold`` of the *smaller* box. This makes the
    check independent of detection order. Of the two, the larger box is kept,
    at the position of the existing entry.
    """
    element_area = _box_area(element.block)
    for index, existing in enumerate(blocks):
        if existing.page_index != page_index:
            continue
        existing_area = _box_area(existing.block.block)
        smaller_area = min(element_area, existing_area)
        if smaller_area <= 0:
            # A zero-area box cannot meaningfully overlap anything (and would divide by zero).
            continue
        overlap_ratio = _intersection_area(element.block, existing.block.block) / smaller_area
        if overlap_ratio > overlap_threshold:
            if element_area > existing_area:
                blocks[index] = Block(block=element, page_index=page_index)
            return
    blocks.append(Block(block=element, page_index=page_index))


def get_blocks_and_texts(layouts: list[lp.Layout]) -> tuple[list[Block], list[Block]]:
    """Split detected layout elements into figure/table blocks and text blocks.

    Args:
        layouts: One layout per page, in page order.

    Returns:
        ``(blocks, texts)``. ``blocks`` holds the ``Table``/``Figure`` elements,
        with near-duplicates on the same page merged (only the larger box is
        kept). ``texts`` holds every ``Text`` element. Other element types are
        ignored.
    """
    blocks: list[Block] = []
    texts: list[Block] = []
    for page_index, layout in enumerate(layouts):
        for element in layout:
            if element.type in ("Table", "Figure"):
                _add_or_merge_block(blocks, element, page_index)
            elif element.type == "Text":
                texts.append(Block(block=element, page_index=page_index))
    return blocks, texts


def _caption_distance(block_rect, text_rect) -> float:
    """Score how likely ``text_rect`` is to be the caption of ``block_rect`` (lower is better).

    Two distances are measured between the horizontal centres of facing edges:

    * block bottom → text top (text below the block);
    * block top → text bottom (text above the block).

    The "below" distance is scaled by ``_BOTTOM_CAPTION_BIAS``.
    """
    block_center_x = (block_rect.x_1 + block_rect.x_2) / 2
    text_center_x = (text_rect.x_1 + text_rect.x_2) / 2
    below = math.dist((block_center_x, block_rect.y_2), (text_center_x, text_rect.y_1))
    above = math.dist((block_center_x, block_rect.y_1), (text_center_x, text_rect.y_2))
    return min(below * _BOTTOM_CAPTION_BIAS, above)


def caption_blocks(blocks: list[Block], texts: list[Block]) -> list[CaptionedBlock]:
    """Pair each figure/table block with the closest text block on the same page.

    Blocks whose page has no text blocks are dropped. Ties go to the earliest
    text block.
    """
    texts_by_page: dict[int, list[Block]] = defaultdict(list)
    for text in texts:
        texts_by_page[text.page_index].append(text)

    captioned_blocks = []
    for block in blocks:
        candidates = texts_by_page.get(block.page_index)
        if not candidates:
            continue
        closest_text = min(
            candidates,
            key=lambda text: _caption_distance(block.block.block, text.block.block),
        )
        captioned_blocks.append(
            CaptionedBlock(
                block=block.block,
                caption=closest_text.block,
                page_index=block.page_index,
            )
        )
    return captioned_blocks


def combine_blocks(captioned_block, pages):
    """Crop the page region that encloses both a block and its caption.

    Args:
        captioned_block: The block/caption pair to crop.
        pages: Page images indexed by ``captioned_block.page_index``.

    Returns:
        The cropped image covering the bounding box of the block and its caption.
    """
    block_rect = captioned_block.block.block
    caption_rect = captioned_block.caption.block
    box = (
        min(block_rect.x_1, caption_rect.x_1),
        min(block_rect.y_1, caption_rect.y_1),
        max(block_rect.x_2, caption_rect.x_2),
        max(block_rect.y_2, caption_rect.y_2),
    )
    return pages[captioned_block.page_index].crop(box)


def _image_name_for(caption_text: str, captioned_block) -> str:
    """Derive a ``.jpg`` file name from the OCR'd caption.

    Runs of non-alphanumeric characters become ``_`` and the result is cut to
    30 characters. If the caption has no usable characters, fall back to a name
    built from the block's page and position. Otherwise every such figure would
    overwrite the same ``_.jpg``/``.jpg`` file.
    """
    stem = re.sub("[^0-9a-zA-Z]+", "_", caption_text)[:30]
    if not stem.strip("_"):
        rect = captioned_block.block.block
        stem = f"page{captioned_block.page_index}_{int(rect.x_1)}_{int(rect.y_1)}"
    return stem + ".jpg"


def process_captioned_block(captioned_block, pages, base_path):
    """Save a figure/table (including its caption) as a JPEG and OCR the caption text.

    The image is written to ``<base_path>/figures/<name>.jpg``.

    Returns:
        ``{"image": "figures/<name>.jpg", "caption": <OCR'd caption text>}``.
    """
    # Encode the combined block + caption crop as JPEG bytes.
    combined_image = combine_blocks(captioned_block, pages)
    buffered = BytesIO()
    combined_image.save(buffered, format="JPEG")

    # OCR only the caption region to get its text.
    caption_rect = captioned_block.caption.block
    caption_image = pages[captioned_block.page_index].crop(
        (caption_rect.x_1, caption_rect.y_1, caption_rect.x_2, caption_rect.y_2)
    )
    caption_text = pytesseract.image_to_string(caption_image)

    figures_path = os.path.join(base_path, "figures")
    os.makedirs(figures_path, exist_ok=True)
    img_name = _image_name_for(caption_text, captioned_block)
    with open(os.path.join(figures_path, img_name), "wb") as f:
        f.write(buffered.getvalue())

    return {"image": f"figures/{img_name}", "caption": caption_text}


def process_pdf(
    content: bytes,
    model: lp.models.Detectron2LayoutModel,
    base_path: str,
    *,
    max_workers: int = 16,
):
    """Detect, caption, crop and save every figure/table in a PDF.

    Args:
        content: Raw PDF bytes.
        model: Layout model; ``model.detect`` is called once per page image.
        base_path: Output directory; images go to ``<base_path>/figures``.
        max_workers: Thread-pool size for layout detection and for cropping/OCR.

    Returns:
        One ``{"image": ..., "caption": ...}`` dict per captioned block, in block order.
    """
    pages = convert_from_bytes(content)
    logger.info("PDF converted to images")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        layouts = list(executor.map(model.detect, pages))
    logger.info("Layout detection completed")

    blocks, texts = get_blocks_and_texts(layouts)
    logger.info("Blocks and texts extracted")

    captioned_blocks = caption_blocks(blocks, texts)
    logger.info("Captioning completed")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(
            executor.map(
                lambda captioned_block: process_captioned_block(captioned_block, pages, base_path),
                captioned_blocks,
            )
        )
    return results


def wait_on_run(run, thread, client: openai.OpenAI, poll_interval: float = 0.5):
    """Poll an Assistants run until it leaves the queued/in-progress/cancelling states.

    Returns:
        The last retrieved run object. Its ``status`` is whatever terminal (or
        action-required) state the run reached.
    """
    while run.status in _PENDING_RUN_STATES:
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
        time.sleep(poll_interval)
    return run


def _parse_thread_json(message_content: str):
    """Extract and decode the JSON payload from the assistant's reply.

    Lookup order:

    1. a ```json fenced block;
    2. a bare ``` fenced block;
    3. the whole reply, for answers that contain raw JSON without fences.

    Raises:
        ValueError: If the selected text is not valid JSON.
    """
    match = _JSON_FENCE_RE.search(message_content) or _PLAIN_FENCE_RE.search(message_content)
    candidate = match.group(1) if match else message_content
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as err:
        raise ValueError("The thread generated by OpenAI was not in the expected JSON format.") from err


def generate_thread_content(pdf_path: str, results: list[dict], client: openai.OpenAI, assistant_id: str):
    """Upload the PDF and ask the assistant to write a thread about it.

    The extracted figure/caption ``results`` are sent as JSON in the prompt. The
    uploaded PDF file is always deleted afterwards (best effort).

    Returns:
        The decoded JSON thread produced by the assistant.

    Raises:
        ValueError: If the run does not complete, returns no message, or
            returns content that is not JSON.
    """
    with open(pdf_path, "rb") as f:
        pdf_file = client.files.create(file=f, purpose="assistants")

    try:
        thread = client.beta.threads.create()
        message = client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=f"{json.dumps(results)}\n\nCreate a thread for this. Your answer must be in JSON, media links should be from the local paths above.",
            file_ids=[pdf_file.id],
        )
        run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant_id)
        run = wait_on_run(run, thread, client)
        if run.status != "completed":
            raise ValueError(f"OpenAI run did not complete (status: {run.status}). Please try again.")

        messages = client.beta.threads.messages.list(thread_id=thread.id, order="asc", after=message.id)
        # A run can complete successfully yet list no new messages; treat that as an error.
        if not messages.data or not messages.data[0].content:
            raise ValueError("Unexpected empty response from OpenAI. Please try again.")
    except Exception as e:
        logger.error(f"Failed to generate thread content: {e}")
        raise
    finally:
        # Best-effort cleanup of the uploaded PDF.
        try:
            client.files.delete(file_id=pdf_file.id)
        except Exception as e:
            logger.error(f"Failed to delete file: {e}")

    message_content = messages.data[0].content[0].text.value
    return _parse_thread_json(message_content)


def process_thread(thread_data, base_path):
    """Clean the assistant's thread and keep only media files that exist on disk.

    Source annotations (e.g. ``【13†source】``) are removed from each entry's
    ``content``. A media item is kept only if all of the following hold:

    * its ``path`` is non-empty;
    * the path was not already kept for an earlier entry;
    * the path names an existing file under ``base_path``.

    Returns:
        A list of ``{"content": str, "media": list}`` dicts, one per input entry.
    """
    processed_data = []
    seen_paths = set()
    for entry in thread_data:
        cleaned_content = _SOURCE_ANNOTATION_RE.sub("", entry["content"])
        kept_media = []
        # ``or []`` also covers an explicit JSON null for "media".
        for media in entry.get("media") or []:
            path = media.get("path")
            if not path or path in seen_paths:
                continue
            if os.path.isfile(os.path.join(base_path, path)):
                kept_media.append(media)
                seen_paths.add(path)
        processed_data.append({"content": cleaned_content, "media": kept_media})
    return processed_data


def render_markdown(processed_thread):
    """Render a processed thread as Markdown.

    Each entry's content is followed by its media as centred HTML ``<img>`` tags
    and a ``---`` separator. Attribute values are HTML-escaped.
    """
    parts = []
    for entry in processed_thread:
        parts.append(entry["content"] + "\n")
        for media in entry["media"]:
            src = html.escape(media["path"], quote=True)
            # NOTE(review): this markup was garbled in the source; using the
            # media item's "explain" field as alt text is an assumption —
            # confirm against the assistant's JSON schema.
            alt = html.escape(str(media.get("explain", "")), quote=True)
            parts.append('\n<div align="center">\n')
            parts.append(f'  <img src="{src}" alt="{alt}" style="max-width: 75%;">\n')
            parts.append("</div>\n")
        parts.append("\n---\n\n")
    return "".join(parts)


def uri_validator(x):
    """Return True if ``x`` parses as an absolute URL (it has both a scheme and a host)."""
    try:
        result = urlparse(x)
    except (TypeError, ValueError, AttributeError):
        # Non-string input or an unparsable URL (e.g. a malformed IPv6 host).
        return False
    return bool(result.scheme and result.netloc)


def _pdf_name_from(pdf_url_or_path: str) -> str:
    """Derive the output directory name from a URL or local path.

    Only a trailing ``.pdf`` extension is removed. Names that contain dots (e.g.
    arXiv ids such as ``2301.00001``) therefore stay intact and do not collide
    in the same output directory.
    """
    if uri_validator(pdf_url_or_path):
        # Ignore query strings/fragments and any trailing slash.
        raw_name = urlparse(pdf_url_or_path).path.rstrip("/").rsplit("/", 1)[-1]
    else:
        raw_name = os.path.basename(os.path.normpath(pdf_url_or_path))
    stem, extension = os.path.splitext(raw_name)
    name = stem if extension.lower() == ".pdf" else raw_name
    return name or "document"


def _fetch_pdf(pdf_url_or_path: str, base_path: str, pdf_path: str) -> bytes:
    """Download or copy the input PDF to ``pdf_path`` and return its bytes.

    Raises:
        ValueError: If the input is neither a URL nor an existing file.
        requests.HTTPError: If the download returns an HTTP error status.
    """
    is_url = uri_validator(pdf_url_or_path)
    if not is_url and not os.path.isfile(pdf_url_or_path):
        raise ValueError(f"Invalid input: {pdf_url_or_path}. It should be a valid URL or a file path.")
    # Create the output directory only once the input is known to be usable.
    os.makedirs(base_path, exist_ok=True)

    if is_url:
        response = requests.get(pdf_url_or_path, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
        # Fail fast instead of saving an HTTP error page as the PDF.
        response.raise_for_status()
        pdf_content = response.content
        with open(pdf_path, "wb") as f:
            f.write(pdf_content)
        return pdf_content

    # Copying a file onto itself raises shutil.SameFileError; skip the copy in that case.
    if not (os.path.exists(pdf_path) and os.path.samefile(pdf_url_or_path, pdf_path)):
        shutil.copy(pdf_url_or_path, pdf_path)
    with open(pdf_path, "rb") as f:
        return f.read()


def _extract_results(pdf_content: bytes, base_path: str) -> list[dict]:
    """Run layout detection on the PDF and return de-duplicated figure/caption records."""
    model = lp.models.Detectron2LayoutModel(
        config_path="lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
        extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5],
        label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
    )
    results = process_pdf(pdf_content, model, base_path)
    # Drop exact duplicates while keeping the original order. A set would
    # reorder the figures differently on every run, because string hashing is
    # randomized per process.
    unique_items = dict.fromkeys(tuple(record.items()) for record in results)
    return [dict(items) for items in unique_items]


def _write_json(path: str, data) -> None:
    """Write ``data`` to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def create_thread(pdf_url_or_path: str, output_path: str, client: openai.OpenAI, assistant_id: str):
    """Run the full PDF-to-thread pipeline and write every artifact to disk.

    Artifacts go to ``<output_path>/<pdf name>/``:

    * the PDF copy;
    * ``figures/``;
    * ``results.json``;
    * ``thread.json``;
    * ``processed_thread.json``;
    * ``processed_thread.md``.

    If ``results.json`` already exists there, figure extraction is skipped and
    the cached results are reused.

    Returns:
        The output directory for this PDF.

    Raises:
        ValueError: If the input is neither a URL nor an existing file, or the
            assistant's reply cannot be turned into a thread.
    """
    pdf_name = _pdf_name_from(pdf_url_or_path)
    base_path = os.path.join(output_path, pdf_name)
    results_path = os.path.join(base_path, "results.json")
    pdf_path = os.path.join(base_path, f"{pdf_name}.pdf")
    thread_path = os.path.join(base_path, "thread.json")
    processed_thread_path = os.path.join(base_path, "processed_thread.json")
    markdown_path = os.path.join(base_path, "processed_thread.md")

    if os.path.isfile(results_path):
        # A previous run already extracted the figures; reuse them.
        with open(results_path, "r", encoding="utf-8") as f:
            results = json.load(f)
    else:
        pdf_content = _fetch_pdf(pdf_url_or_path, base_path, pdf_path)
        results = _extract_results(pdf_content, base_path)
        _write_json(results_path, results)

    paper_thread = generate_thread_content(pdf_path, results, client, assistant_id)
    _write_json(thread_path, paper_thread)

    processed_thread = process_thread(paper_thread, base_path)
    _write_json(processed_thread_path, processed_thread)

    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write(render_markdown(processed_thread))

    logger.info(f"Saved all outputs to: {os.path.abspath(base_path)}")
    return base_path


def create_assistant_then_thread(
    pdf_url_or_path: str,
    output_path: str,
    client: openai.OpenAI,
    assistant_kwargs: Optional[dict] = None,
):
    """Create a temporary assistant, run ``create_thread`` with it, then delete it.

    Deleting the assistant is best-effort. A failure there is logged but never
    replaces the thread-creation error, and never discards a successful result.

    Returns:
        The output directory returned by ``create_thread``.
    """
    if assistant_kwargs is None:
        assistant_kwargs = {}

    try:
        assistant = create_assistant(client, **assistant_kwargs)
    except Exception:
        logger.error("Failed to create assistant", exc_info=True)
        raise

    try:
        return create_thread(pdf_url_or_path, output_path, client, assistant.id)
    except Exception:
        logger.error("Failed to create thread", exc_info=True)
        raise
    finally:
        try:
            client.beta.assistants.delete(assistant.id)
        except Exception:
            logger.error("Failed to delete assistant", exc_info=True)


def main(argv=None):
    """Command-line entry point: process one PDF URL or path into ``--output``."""
    parser = argparse.ArgumentParser(description="Process a PDF from a URL or a local path.")
    parser.add_argument("url_or_path", type=str, help="The URL or local path of the PDF to process.")
    parser.add_argument(
        "-o",
        "--output",
        default="data",
        help="The output directory to store the results.",
    )
    args = parser.parse_args(argv)

    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    create_assistant_then_thread(args.url_or_path, args.output, client)


if __name__ == "__main__":
    main()