File size: 11,858 Bytes
50bee63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import traceback
import gradio as gr
from utils.get_RGB_image import get_RGB_image, is_online_file, steam_online_file
import layoutparser as lp
from PIL import Image
from utils.get_features import get_features
from imagehash import average_hash
from sklearn.metrics.pairwise import cosine_similarity
from utils.visualize_bboxes_on_image import visualize_bboxes_on_image
import fitz

# Detectron2 class index -> DocLayNet layout-element name.
label_map = {0: 'Caption', 1: 'Footnote', 2: 'Formula', 3: 'List-item', 4: 'Page-footer',
             5: 'Page-header', 6: 'Picture', 7: 'Section-header', 8: 'Table', 9: 'Text', 10: 'Title'}
# Label names in class-index order; passed to get_features.
label_names = list(label_map.values())
# Hex color used for each label's bounding-box outline/fill in the visualization.
color_map = {'Caption': '#FF0000', 'Footnote': '#00FF00', 'Formula': '#0000FF', 'List-item': '#FF00FF', 'Page-footer': '#FFFF00',
             'Page-header': '#000000', 'Picture': '#FFFFFF', 'Section-header': '#40E0D0', 'Table': '#F28030', 'Text': '#7F00FF', 'Title': '#C0C0C0'}
# Module-level memo for the most recently processed pair of documents.
# NOTE: the stored hashes are of the *annotated* output images, so feeding an
# annotated output back in is detected as a cache hit and the cached features
# plus the original (un-annotated) image are reused (see similarity_fn).
cache = {
    'output_document_image_1_hash': None,  # average_hash of last annotated output 1
    'output_document_image_2_hash': None,  # average_hash of last annotated output 2
    'document_image_1_features': None,     # features dict from get_features for document 1
    'document_image_2_features': None,     # features dict from get_features for document 2
    'original_document_image_1': None,     # un-annotated PIL image for document 1
    'original_document_image_2': None      # un-annotated PIL image for document 2
}
# Inline CSS for the status/error banner rendered by similarity_fn.
pre_message_style = 'border:2px solid pink;padding:4px;border-radius:4px;font-size: 16px;font-weight: 700;background-image: linear-gradient(to bottom right, #e0e619, #ffffff, #FF77CC, rgb(255, 122, 89));'
# Shared label-styling kwargs for visualize_bboxes_on_image.
visualize_bboxes_on_image_kwargs = {
    'label_text_color': 'white',
    'label_fill_color': 'black',
    'label_text_size': 12,
    'label_text_padding': 3,
    'label_rectangle_left_margin': 0,
    'label_rectangle_top_margin': 0
}
# Feature-vector variants produced by get_features; offered in the UI dropdown.
vectors_types = ['vectors', 'weighted_vectors',
                 'reduced_vectors', 'reduced_weighted_vectors']


def _features_for(document_image: Image.Image, input_hash: str, index: int, model: lp.Detectron2LayoutModel):
    """Return (features, image) for document `index`, reusing the module cache.

    A cache hit is detected by comparing the incoming image's hash against the
    stored hash of the previously *annotated* output image, so re-submitting an
    annotated result skips re-prediction and restores the original image.
    """
    if input_hash == cache[f'output_document_image_{index}_hash']:
        return (cache[f'document_image_{index}_features'],
                cache[f'original_document_image_{index}'])
    gr.Info(f'Generating features for document {index}')
    features = get_features(document_image, model, label_names)
    cache[f'document_image_{index}_features'] = features
    cache[f'original_document_image_{index}'] = document_image
    return features, document_image


def _annotate(document_image: Image.Image, features: dict, annotations: dict) -> Image.Image:
    """Draw the (possibly reduced) predicted layout bboxes with label/score tags."""
    labels = features[annotations['predicted_labels']]
    scores = features[annotations['predicted_scores']]
    return visualize_bboxes_on_image(
        image=document_image,
        bboxes=features[annotations['predicted_bboxes']],
        labels=[f'{label}, score:{round(score, 2)}' for label, score in zip(labels, scores)],
        bbox_outline_color=[color_map[label] for label in labels],
        bbox_fill_color=[(color_map[label], 50) for label in labels],
        **visualize_bboxes_on_image_kwargs)


def similarity_fn(model: lp.Detectron2LayoutModel, document_image_1: Image.Image, document_image_2: Image.Image, vectors_type: str):
    """Compute the cosine similarity between two document images' layout features.

    Extracts (or reuses cached) layout features for both images, computes the
    cosine similarity of the selected vectors, and annotates both images with
    the predicted bounding boxes.

    Args:
        model: layout-detection model used by get_features.
        document_image_1: first document image (or None if not loaded).
        document_image_2: second document image (or None if not loaded).
        vectors_type: one of `vectors_types`; selects which feature vector
            (and which bbox set: full vs reduced) is used.

    Returns:
        [message_html, annotated_image_1, annotated_image_2, dropdown_update];
        on failure the message carries the traceback and the images are
        returned unchanged.
    """
    message = None
    # The two 'reduced_*' vector types are visualized with the reduced bbox set.
    use_full = vectors_type in ['vectors', 'weighted_vectors']
    annotations = {
        'predicted_bboxes': 'predicted_bboxes' if use_full else 'reduced_predicted_bboxes',
        'predicted_scores': 'predicted_scores' if use_full else 'reduced_predicted_scores',
        'predicted_labels': 'predicted_labels' if use_full else 'reduced_predicted_labels',
    }
    show_vectors_type = False
    try:
        if document_image_1 is None or document_image_2 is None:
            message = 'Please load both the documents to compare.'
            gr.Info(message)
        else:
            features_1, document_image_1 = _features_for(
                document_image_1, str(average_hash(document_image_1)), 1, model)
            features_2, document_image_2 = _features_for(
                document_image_2, str(average_hash(document_image_2)), 2, model)

            gr.Info('Calculating similarity')
            [[similarity]] = cosine_similarity(
                [features_1[vectors_type]],
                [features_2[vectors_type]])
            message = f'Similarity between the two documents is: {round(similarity, 4)}'
            gr.Info(message)
            gr.Info('Visualizing the bounding boxes for the predicted layout elements on the documents.')
            document_image_1 = _annotate(document_image_1, features_1, annotations)
            document_image_2 = _annotate(document_image_2, features_2, annotations)

            # Remember the annotated outputs' hashes so re-submitting them is a cache hit.
            cache['output_document_image_1_hash'] = str(
                average_hash(document_image_1))
            cache['output_document_image_2_hash'] = str(
                average_hash(document_image_2))

            show_vectors_type = True
    except Exception:
        # Surface the full traceback in the UI banner rather than failing silently.
        message = f'<pre style="overflow:auto;">{traceback.format_exc()}</pre>'
        gr.Info(message)
    return [
        gr.HTML(f'<div style="{pre_message_style}">{message}</div>', visible=True),
        document_image_1,
        document_image_2,
        gr.Dropdown(visible=show_vectors_type)
    ]


def load_image(filename, page=0):
    """Load `filename` (PDF or image, local path or URL) as a PIL RGB image.

    Tries PyMuPDF first, rendering page `page` of the document; if that fails,
    falls back to loading the file as a plain image via get_RGB_image.

    Args:
        filename: local path or URL of a PDF or image file.
        page: zero-based page index to render when the file is a PDF.

    Returns:
        [image, error_update] — on success `image` is a PIL.Image and the
        error slot is None; on total failure `image` is None and the error
        slot is a visible gr.HTML carrying both errors.
    """
    first_error = None
    try:
        try:
            if is_online_file(filename):
                document = fitz.open("pdf", steam_online_file(filename))
            else:
                document = fitz.open(filename)
            # Close the document promptly instead of leaking the handle.
            with document:
                pixmap = document[page].get_pixmap()
            image = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
        except Exception as pdf_error:
            # Not a readable PDF — fall back to treating it as an image file.
            first_error = pdf_error
            image = get_RGB_image(filename)
        return [
            image,
            None
        ]
    except Exception as second_error:
        error = f'{traceback.format_exc()}\n\nFirst Error:\n{first_error}\n\nSecond Error:\n{second_error}'
        return [None, gr.HTML(value=error, visible=True)]


def preview_url(url, page=0):
    """Load a document from `url` and choose which tab to display.

    Returns [tabs_update, image, error_update]: tab 0 (the image preview)
    is selected when the document loaded, tab 1 (the URL form) otherwise.
    """
    image, error = load_image(url, page=page)
    selected_tab = 0 if image else 1
    return [gr.Tabs(selected=selected_tab), image, error]


def document_view(document_number: int, examples: list[str] = None):
    """Render the upload/URL tabs for one document and wire up their events.

    Args:
        document_number: 1 or 2; used only in user-facing labels.
        examples: optional example image paths to show under the tabs.
            Defaults to None (avoids a mutable default argument).

    Returns:
        The gr.Image component that holds the loaded document.
    """
    examples = examples or []
    gr.HTML(value=f'<h4>Load the {"first" if document_number == 1 else "second"} PDF or Document Image</h4>', elem_classes=[
            'center'])
    # Fixed user-facing typos ("upload Upload", "cleck") and dropped the
    # needless f-prefix on a placeholder-free string.
    gr.HTML(value='<p>Click the button below to upload a PDF or Document Image, or click the URL tab to add one using a link.</p>', elem_classes=[
            'center'])
    with gr.Tabs() as document_tabs:
        with gr.Tab("From Image", id=0):
            # Hidden until an image is loaded; see document.change below.
            document = gr.Image(
                type="pil", label=f"Document {document_number}", visible=False, interactive=False, show_download_button=True)
            document_error_message = gr.HTML(
                label="Error Message", visible=False)
            document_preview = gr.UploadButton(
                label="Upload PDF or Document Image",
                file_types=["image", ".pdf"],
                file_count="single")
        with gr.Tab("From URL", id=1):
            document_url = gr.Textbox(
                label=f"Document {document_number} URL",
                info="Paste a Link/URL to PDF or Document Image",
                placeholder="https://datasets-server.huggingface.co/.../image.jpg")
            document_url_error_message = gr.HTML(
                label="Error Message", visible=False)
            document_url_preview = gr.Button(
                value="Preview Link Document", variant="secondary")
        if examples:
            gr.Examples(
                examples=examples,
                inputs=document,
                label='Select any of these test document images')
    # Load from an uploaded file.
    document_preview.upload(
        fn=lambda file: load_image(file.name),
        inputs=[document_preview],
        outputs=[document, document_error_message])
    # Load from a URL; preview_url also switches back to the image tab on success.
    document_url_preview.click(
        fn=preview_url,
        inputs=[document_url],
        outputs=[document_tabs, document, document_url_error_message])
    # Show the image component only when it actually holds an image.
    document.change(
        fn=lambda image: gr.Image(value=image, visible=True) if image else gr.Image(value=None, visible=False),
        inputs=[document],
        outputs=[document])
    return document


def app(*, model_path: str, config_path: str, examples: list[str], debug=False):
    """Build and launch the Gradio document-similarity interface.

    Args:
        model_path: path to the Detectron2 model weights.
        config_path: path to the Detectron2 model config.
        examples: example document-image paths shown under each document view.
        debug: forwarded to gr.Blocks.launch.

    Returns:
        The result of interface.launch().
    """
    model: lp.Detectron2LayoutModel = lp.Detectron2LayoutModel(
        config_path=config_path,
        model_path=model_path,
        label_map=label_map)
    title = 'Document Similarity Search Using Visual Layout Features'
    # Fixed: the heading was previously left unclosed ("<h2>{title}<h2>").
    description = f"<h2>{title}</h2>"
    # Fixed: 'image { max-height="86vh" !important; }' was invalid CSS
    # (attribute syntax instead of a declaration; 'image' is not an element).
    css = '''
    img { max-height: 86vh !important; }
    .center { display: flex; flex: 1 1 auto; align-items: center; align-content: center; justify-content: center; justify-items: center; }
    .hr { width: 100%; display: block; padding: 0; margin: 0; background: gray; height: 4px;  border: none; }
    '''
    with gr.Blocks(title=title, css=css) as interface:
        with gr.Row():
            gr.HTML(value=description, elem_classes=['center'])
        with gr.Row(equal_height=False):
            with gr.Column():
                document_1_image = document_view(1, examples)
            with gr.Column():
                document_2_image = document_view(2, examples)
        gr.HTML('<hr/>', elem_classes=['hr'])
        with gr.Row(elem_classes=['center']):
            with gr.Column():
                submit = gr.Button(value="Get Similarity", variant="primary")
            with gr.Column():
                vectors_type = gr.Dropdown(
                    choices=vectors_types,
                    value=vectors_types[0],
                    visible=False,
                    label="Vectors Type",
                    info="Select the Vectors Type to use for Similarity Calculation")
                similarity_output = gr.HTML(
                    label="Similarity Score", visible=False)
        # One handler serves both the submit button and the dropdown change.
        kwargs = {
            'fn': lambda document_1_image, document_2_image, vectors_type: similarity_fn(
                model,
                document_1_image,
                document_2_image,
                vectors_type),
            'inputs': [document_1_image, document_2_image, vectors_type],
            'outputs': [similarity_output, document_1_image, document_2_image, vectors_type]
        }
        submit.click(**kwargs)
        vectors_type.change(**kwargs)
    return interface.launch(debug=debug)