File size: 9,775 Bytes
948c8ce
 
 
 
 
 
45b6f79
4e1ec1c
4eac50b
1fc115f
4e1ec1c
 
 
45b6f79
1796549
 
45b6f79
 
 
4eac50b
4e1ec1c
1796549
832e4f3
948c8ce
73f430f
948c8ce
 
 
4e1ec1c
45b6f79
1796549
832e4f3
4e1ec1c
 
 
 
 
 
 
 
1fc115f
 
 
 
 
 
948c8ce
be1e49c
 
 
 
4e1ec1c
 
 
 
 
ff86a3f
4e1ec1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
948c8ce
4e1ec1c
948c8ce
be1e49c
4e1ec1c
 
 
832e4f3
be1e49c
948c8ce
 
 
 
45b6f79
948c8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
4e1ec1c
 
832e4f3
45b6f79
 
832e4f3
 
45b6f79
 
832e4f3
 
45b6f79
 
 
 
 
 
 
832e4f3
 
45b6f79
 
 
 
832e4f3
45b6f79
832e4f3
948c8ce
45b6f79
be1e49c
948c8ce
 
 
45b6f79
948c8ce
 
 
 
 
1fc115f
 
4e1ec1c
1fc115f
 
 
45b6f79
1fc115f
45b6f79
4e1ec1c
7b6a165
1fc115f
4e1ec1c
832e4f3
45b6f79
832e4f3
 
e3dcfdd
832e4f3
 
 
 
10a9ffa
832e4f3
 
10a9ffa
7b6a165
1d82c63
be1e49c
 
1d82c63
 
832e4f3
45b6f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832e4f3
45b6f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
832e4f3
 
45b6f79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import os
import subprocess  # 🥲

# HF Spaces workaround: install flash-attn at startup, skipping the CUDA
# source build so the prebuilt wheel is used.
# BUG FIX: `env=` REPLACES the child's entire environment — the previous
# code dropped PATH/HOME, so the shell could not even locate `pip`.
# Merge the flag into the current environment instead.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import spaces
import gradio as gr
import re
import torch
import os
import json
import time
from pydantic import BaseModel
from typing import Tuple
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# ----------------------- Model and Processor Loading ----------------------- #
# Qwen2.5-VL vision-language model: bf16 weights, FlashAttention-2 kernels,
# automatically placed across available devices.
_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# Companion processor: chat templating, image preprocessing, tokenization.
processor = AutoProcessor.from_pretrained(_MODEL_ID)

# ----------------------- Pydantic Model Definition ----------------------- #
class GeneralRetrievalQuery(BaseModel):
    """Schema of the JSON the VLM is prompted to return: three retrieval
    queries (broad topical, specific detail, visual element), each paired
    with a short explanation of why it would retrieve the document."""
    broad_topical_query: str
    broad_topical_explanation: str
    specific_detail_query: str
    specific_detail_explanation: str
    visual_element_query: str
    visual_element_explanation: str

# Matches the first ``` / ```json fenced block; compiled once at import time
# since the pattern is fixed.
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*(.+?)\s*```', re.DOTALL)

def extract_json_with_regex(text):
    """Extract the contents of the first fenced code block in *text*.

    Accepts both ``` and ```json fences; surrounding whitespace inside the
    fence is stripped by the pattern itself.

    Args:
        text: Raw model output, possibly containing a fenced JSON block.

    Returns:
        The inner text of the first fence, or None if no fence is found.
    """
    # re.search stops at the first match instead of materializing every
    # match the way findall did.
    match = _JSON_FENCE_RE.search(text)
    return match.group(1) if match else None

def get_retrieval_prompt(prompt_name: str) -> "Tuple[str, type[GeneralRetrievalQuery]]":
    """Return the retrieval-query prompt and the schema class for its answer.

    Args:
        prompt_name: Prompt flavor to use; only "general" is supported.

    Returns:
        A (prompt text, GeneralRetrievalQuery) pair.  BUG FIX in the
        annotation: the second element is the pydantic *class* itself,
        not an instance, so the return type is ``type[GeneralRetrievalQuery]``.

    Raises:
        ValueError: If *prompt_name* is anything other than "general".
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")
    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.

For each query, also provide a brief explanation of why this query would be effective in retrieving this document.

Format your response as a JSON object with the following structure:

{
  "broad_topical_query": "Your query here",
  "broad_topical_explanation": "Brief explanation",
  "specific_detail_query": "Your query here",
  "specific_detail_explanation": "Brief explanation",
  "visual_element_query": "Your query here",
  "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query.

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format."""
    return prompt, GeneralRetrievalQuery

prompt, pydantic_model = get_retrieval_prompt("general")

# ----------------------- Input Preprocessing ----------------------- #
def _prep_data_for_input(image):
    """Build the processor inputs for one user turn holding *image* plus the
    module-level retrieval prompt: chat-templated text and vision tensors,
    returned as a padded "pt" batch of size one."""
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    templated = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(conversation)
    model_inputs = processor(
        text=[templated],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return model_inputs

# ----------------------- Output Formatting ----------------------- #
def format_output(data: dict, output_format: str) -> str:
    """
    Convert the JSON data into the desired output format.

    Args:
        data: Parsed query dict (field name -> text).
        output_format: "JSON", "Markdown", or "Table"; any other value
            falls back to JSON.

    Returns:
        A Markdown-renderable string.
    """
    if output_format == "Markdown":
        md_lines = [
            f"**{key.replace('_', ' ').title()}:** {value}"
            for key, value in data.items()
        ]
        return "\n\n".join(md_lines)
    if output_format == "Table":
        headers = ["Field", "Content"]
        separator = " | ".join(["---"] * len(headers))
        rows = [f"| {' | '.join(headers)} |", f"| {separator} |"]
        for key, value in data.items():
            # Escape characters that would otherwise break the Markdown
            # table layout (a raw "|" splits the cell, a newline ends it).
            cell = str(value).replace("|", "\\|").replace("\n", "<br>")
            rows.append(f"| {key.replace('_', ' ').title()} | {cell} |")
        return "\n".join(rows)
    # "JSON" and any unknown format: pretty-printed JSON in a code fence
    # (single code path instead of the duplicated if/else branches).
    return f"```json\n{json.dumps(data, indent=2, ensure_ascii=False)}\n```"

# ----------------------- Response Generation ----------------------- #
@spaces.GPU
def generate_response(image, output_format: str = "JSON"):
    """Run the VLM on *image*, parse the JSON it emits, and render it with
    format_output(); on any parse failure a Gradio warning is shown and the
    raw decoded text is returned instead."""
    model_inputs = _prep_data_for_input(image).to("cuda")
    generated = model.generate(**model_inputs, max_new_tokens=200)
    # Drop the prompt prefix from each sequence, keeping only new tokens.
    trimmed = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(model_inputs.input_ids, generated)
    ]
    decoded = processor.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    try:
        # Prefer the fenced JSON block when present; otherwise try the
        # whole output as bare JSON.
        fenced = extract_json_with_regex(decoded)
        parsed = json.loads(fenced if fenced else decoded)
        return format_output(parsed, output_format)
    except Exception:
        gr.Warning("Failed to parse JSON from output")
        return decoded

# ----------------------- Interface Title and Description (in English) ----------------------- #
# Static UI copy rendered at the top of the Blocks layout below.
title = "Elegant ColPali Query Generator using Qwen2.5-VL"
description = """**ColPali** is a multimodal approach optimized for document retrieval.
This interface uses the [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) model to generate relevant retrieval queries based on a document image.

The queries include:
- **Broad Topical Query:** Covers the main subject of the document.
- **Specific Detail Query:** Focuses on a particular fact, figure, or point from the document.
- **Visual Element Query:** References a visual component (e.g., chart, graph) from the document.

Refer to the examples below to generate queries suitable for your document image.
For more information, please see the associated blog post.
"""

# NOTE(review): this list is never referenced — the gr.Examples component
# below repeats the same paths inline; consider wiring this through or
# removing the duplicate.
examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]

# ----------------------- Custom CSS ----------------------- #
# Light-theme styling injected into the Blocks app: page background and
# typography, centered header, blue buttons with a hover state, rounded
# gallery cards, and a muted footer.
custom_css = """
body {
    background: #f7f9fb;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gradio-container {
    padding: 20px;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
footer {
    text-align: center;
    padding: 20px 0;
    font-size: 0.9em;
    color: #555;
}
"""

# ----------------------- Gradio Interface ----------------------- #
# Statement order here defines the visual layout, so the construction code
# is kept as-is and only annotated.
with gr.Blocks(css=custom_css, title=title) as demo:
    with gr.Column(variant="panel"):
        gr.Markdown(f"<header><h1>{title}</h1></header>")
        gr.Markdown(description)
        
        with gr.Tabs():
            with gr.TabItem("Query Generation"):
                gr.Markdown("### Generate Retrieval Queries from a Document Image")
                with gr.Row():
                    image_input = gr.Image(label="Upload Document Image", type="pil")
                with gr.Row():
                    output_format = gr.Radio(
                        choices=["JSON", "Markdown", "Table"],
                        value="JSON",
                        label="Output Format",
                        info="Select the desired output format."
                    )
                generate_button = gr.Button("Generate Query")
                # The output component is gr.Markdown so that the Markdown
                # and Table output formats render properly.
                output_text = gr.Markdown(label="Generated Query")
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Query Examples",
                        examples=[
                            "examples/Approche_no_13_1977.pdf_page_22.jpg",
                            "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
                        ],
                        inputs=image_input,
                    )
                # Wire the button to inference: (image, format) -> Markdown.
                generate_button.click(
                    fn=generate_response,
                    inputs=[image_input, output_format],
                    outputs=output_text
                )
        
        gr.Markdown("<footer>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")

demo.launch()