# Copyright (c) 2022, Lawrence Livermore National Security, LLC. 
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964

# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
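
"""Gradio demo for "Detect, Retrieve, Comprehend": zero-shot document-level QA over PDFs.

Pipeline: (1) DiT layout analysis, (2) paragraph-based context extraction,
(3) CrossEncoder context retrieval, (4) UnifiedQA answer prediction.
"""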

import torch
import gradio as gr
from pathlib import Path

from torchvision.transforms import ToPILImage, ToTensor
tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()

import sys
sys.path.append('DiT_Extractor/')
sys.path.append('CrossEncoder/')
sys.path.append('UnifiedQA/')

import dit_runner
import sentence_extractor
import cross_encoder
import demo_QA


def run_fn(pdf_file_obj, question_text, input_topk):
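    """Run the full document-QA pipeline on an uploaded PDF.

    Args:
        pdf_file_obj: Gradio file object for the uploaded PDF (path available via `.name`).
        question_text: Question to answer about the document.
        input_topk: Number of top-ranked contexts passed to UnifiedQA.

    Returns:
        Tuple of (viz_images, contexts_json, ranked_contexts_json, history) bound to
        the gallery, the two file outputs, and the Chatbot component.
    """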
    
    pdf = pdf_file_obj.name
    print('Running PDF: {0}'.format(pdf))
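    # Step 1: DiT layout analysis. Returns per-page visualizations for the gallery;
    # entity predictions are written to '<pdf name>.json' (read below).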
    viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)
    entity_json = '{0}.json'.format(Path(pdf).stem)
    
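    # Step 2: extract paragraph-level text contexts from the predicted entities.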
    sentence_extractor.get_contexts(entity_json)
    
    contexts_json = 'contexts_{0}'.format(entity_json)
    # contexts_json = 'contexts_2105u2iwiwxh.03011.json'
    
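    # Step 3: rank the extracted contexts against the question with the CrossEncoder.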
    cross_encoder.get_ranked_contexts(contexts_json, question_text)
    
    ranked_contexts_json = 'ranked_{0}'.format(contexts_json)
    # ranked_contexts_json = 'ranked_contexts_2105u2iwiwxh.03011.json'
    
    input_topk = int(input_topk)
    
    # viz_images = [tensor_to_image(x) for x in torch.randn(4, 3, 256, 256)]
    
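    # Step 4: run UnifiedQA over the top-k ranked contexts.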
    qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, input_topk)
    
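    # Pair each retrieved context (prefixed with its retrieval score) with its answer
    # for display in the Chatbot component.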
    history = [('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(s, c), a) for c, s, a in zip(qa_results['contexts'], qa_results['context_scores'], qa_results['answers'])]
    
    # Show in ascending order of score, since results box is already scrolled down.
    history = history[::-1]
    
    return viz_images, contexts_json, ranked_contexts_json, history
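
# A minimal sketch of calling the pipeline outside the UI (hypothetical; assumes the
# bundled example PDFs are present and uses a stand-in object exposing the `.name`
# attribute that Gradio's file object provides):
#
#   class _FakeUpload:
#       name = 'examples/1909.00694.pdf'
#
#   viz, contexts, ranked, history = run_fn(_FakeUpload(), 'What is the seed lexicon?', 5)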

demo = gr.Blocks()

with demo:
    
    gr.Markdown("<h1><center>Detect-Retrieve-Comprehend for Document-Level QA</center></h1>")
    gr.Markdown("<center>This is a supplemental demo for our recent paper, expected to be publically available around October: <b>Detect, Retrieve, Comprehend: A Flexible Framework for Zero-Shot Document-Level Question Answering</b>. In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation. Note that demo runtimes may be between 2-8 minutes, since this is currently cpu-based Space.</center>")
    
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_pdf_file = gr.File(file_count='single', label='PDF File')
            with gr.Row():
                input_question_text = gr.Textbox(label='Question')
            with gr.Row():
                input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
            with gr.Row():
                button_run = gr.Button('Run QA on Document')
            
            gr.Markdown("<h3><center>Summary</center></h3>")
            with gr.Row():
                gr.Markdown('''
                - <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
                  - Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc.
                - <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using pdfminer.six -> Tokenize & Sentence-Split if tokenizer max length is exceeded
                - <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
                - <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
                ''')

        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(label='DiT Predicted Entities')
            with gr.Row():
                gr.Markdown('''
                - The `DiT Predicted Entities` output box is scrollable! Scroll to see predictions for different pages. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
                - If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab".
                ''')
            with gr.Row():
                output_contexts = gr.File(label='Detected Contexts', interactive=False)
                output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
            with gr.Row():
                output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()
                
    gr.Markdown("<h3><center>Related Work & Code</center></h3>")
    gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")
                
    button_run.click(fn=run_fn, inputs=[input_pdf_file, input_question_text, input_k_percent], outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results])
    
    examples = [
        ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
        ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
        ['examples/1810.04805.pdf', 'What is this paper about?', 5],
        ['examples/1810.04805.pdf', 'What is the model size?', 5],
        ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
        ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],

    ]
    gr.Examples(examples=examples,
                inputs=[input_pdf_file, input_question_text, input_k_percent])
    
    # examples = gr.Dataset(components=[input_pdf_file, input_question_text], samples=[[open('examples/1810.04805.pdf', mode='rb'), 'How many parameters are in the model?']])
    demo.launch(enable_queue=True)