# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
"""Gradio demo: question answering over an uploaded PDF.

The UI collects a PDF, a question and a top-k value, then `run_fn` calls the
four pipeline stages in order (`dit_runner`, `sentence_extractor`,
`cross_encoder`, `demo_QA`) and returns their outputs to the page.
"""

import sys
from pathlib import Path

import torch
import gradio as gr
from torchvision.transforms import ToPILImage, ToTensor

# The pipeline modules are imported by bare name, so their directories have to
# be on sys.path before the imports below run.
sys.path.append('DiT_Extractor/')
sys.path.append('CrossEncoder/')
sys.path.append('UnifiedQA/')

import dit_runner
import sentence_extractor
import cross_encoder
import demo_QA

tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()


def run_fn(pdf_file_obj, question_text, input_topk):
    """Run the pipeline stages on one uploaded PDF and one question.

    Args:
        pdf_file_obj: Value of the ``gr.File`` input. Only its ``name``
            attribute is read and used as the path of the PDF.
        question_text: Question forwarded to
            ``cross_encoder.get_ranked_contexts``.
        input_topk: Value of the "Top K" slider. It is converted with
            ``int()`` and forwarded to ``demo_QA.get_qa_results``.

    Returns:
        A 4-tuple ``(viz_images, contexts_json, ranked_contexts_json,
        history)``: the value returned by ``dit_runner.get_dit_preds``, the
        two JSON file names derived from the PDF name, and a list of
        ``(scored context text, answer)`` pairs in reversed result order.

    Raises:
        ValueError: If no file was supplied.
    """
    # Fail with a readable message instead of an AttributeError on None when
    # the button is pressed before a file has been uploaded.
    if pdf_file_obj is None:
        raise ValueError('No PDF file was provided; please upload a PDF before running.')

    pdf = pdf_file_obj.name
    print(f'Running PDF: {pdf}')
    viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)

    # Drops the last four characters of the file name (e.g. ".pdf").
    # NOTE(review): presumably this has to match the file name the DiT stage
    # writes; confirm before replacing it with Path.stem.
    entity_json = f'{Path(pdf).name[:-4]}.json'
    sentence_extractor.get_contexts(entity_json)

    contexts_json = f'contexts_{entity_json}'
    # contexts_json = 'contexts_2105u2iwiwxh.03011.json'
    cross_encoder.get_ranked_contexts(contexts_json, question_text)

    ranked_contexts_json = f'ranked_{contexts_json}'
    # ranked_contexts_json = 'ranked_contexts_2105u2iwiwxh.03011.json'

    input_topk = int(input_topk)
    # viz_images = [tensor_to_image(x) for x in torch.randn(4, 3, 256, 256)]
    qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, input_topk)

    history = [
        (f'<<< [Retrieval Score: {score:.02f}] >>> {context}', answer)
        for context, score, answer in zip(
            qa_results['contexts'],
            qa_results['context_scores'],
            qa_results['answers'],
        )
    ]
    # Show in ascending order of score, since results box is already scrolled down.
    history.reverse()

    return viz_images, contexts_json, ranked_contexts_json, history


demo = gr.Blocks()

# NOTE(review): the heading and "Related Work" strings below contain no markup
# or URLs (e.g. "Arxiv Page | Github Repo" is plain text) - confirm that this
# is intended and that nothing was stripped from them.
with demo:
    gr.Markdown("\n\nDetect-Retrieve-Comprehend for Document-Level QA\n\n")
    gr.Markdown(
        "\nThis is a supplemental demo for our recent paper, expected to be "
        "publically available around October: Detect, Retrieve, Comprehend: "
        "A Flexible Framework for Zero-Shot Document-Level Question Answering. "
        "In this system, our input is a PDF file with a specific question of "
        "interest. The output is a set of most probable answers. There are 4 "
        "main components in our deployed pipeline: (1) DiT Layout Analysis "
        "(2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. "
        "See below for example uses with further explanation. Note that demo "
        "runtimes may be between 2-8 minutes, since this is currently "
        "cpu-based Space.\n"
    )

    with gr.Row():
        # Left column: inputs, pipeline summary and clickable examples.
        with gr.Column():
            with gr.Row():
                input_pdf_file = gr.File(file_count='single', label='PDF File')
            with gr.Row():
                input_question_text = gr.Textbox(label='Question')
            with gr.Row():
                input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
            with gr.Row():
                button_run = gr.Button('Run QA on Document')

            gr.Markdown("\n\nSummary\n\n")
            with gr.Row():
                gr.Markdown('''
                - **DiT - Document Image Transformer**: PDF -> converted into a list of images -> each image receives Entity Predictions
                    - Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc
                - **Paragraph-based Text Extraction**: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using PDFMiner6 -> Tokenize & Sentence Split if tokenizer max length is exceeded
                - **CrossEncoder Context Retrieval**: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
                - **UnifiedQA**: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
                ''')
            with gr.Row():
                # Each example row is [pdf path, question, top-k], in the
                # same order as the `inputs` list given to gr.Examples.
                examples = [
                    ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
                    ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
                    ['examples/1810.04805.pdf', 'What is this paper about?', 5],
                    ['examples/1810.04805.pdf', 'What is the model size?', 5],
                    ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
                    ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
                ]
                gr.Examples(
                    examples=examples,
                    inputs=[input_pdf_file, input_question_text, input_k_percent],
                )

        # Right column: outputs produced by run_fn.
        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(label='DiT Predicted Entities')
            with gr.Row():
                gr.Markdown('''
                - The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
                - If an image is clicked, the output box will switch to a gallery view.
                To view these outputs in much higher resolution, right-click and choose "open image in new tab"
                ''')
            with gr.Row():
                output_contexts = gr.File(label='Detected Contexts', interactive=False)
                output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
            with gr.Row():
                output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()

    gr.Markdown("\n\nRelated Work & Code\n\n")
    gr.Markdown("\nDiT (Document Image Transformer) - Arxiv Page | Github Repo\n")
    gr.Markdown("\nCrossEncoder - Arxiv Page | Github Repo\n")
    gr.Markdown("\nUnifiedQA - Arxiv Page | Github Repo\n")

    # The order of `outputs` matches the tuple returned by run_fn.
    button_run.click(
        fn=run_fn,
        inputs=[input_pdf_file, input_question_text, input_k_percent],
        outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results],
    )

demo.launch(enable_queue=True)