import os
import subprocess

# Install flash-attn at startup without compiling CUDA kernels; the
# FLASH_ATTENTION_SKIP_CUDA_BUILD flag relies on a prebuilt wheel. Pass the
# parent environment through so the shell keeps its PATH and can find `pip`.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
import json
import re
from typing import Tuple

# hf_transfer must be enabled before huggingface_hub is first imported
# (transformers imports it), otherwise the flag is read too late to apply.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import gradio as gr
import spaces
import torch
from pydantic import BaseModel
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# ----------------------- Model and Processor Loading ----------------------- #
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
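# Note: flash_attention_2 assumes the flash-attn wheel installed above loaded
# successfully and that the GPU supports bfloat16; device placement is handled
# by device_map="auto".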
# ----------------------- Pydantic Model Definition ----------------------- #
class GeneralRetrievalQuery(BaseModel):
    """Schema for the three query/explanation pairs the model is asked to emit."""

    broad_topical_query: str
    broad_topical_explanation: str
    specific_detail_query: str
    specific_detail_explanation: str
    visual_element_query: str
    visual_element_explanation: str
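# A minimal validation sketch (not wired into the app as written): calling
# GeneralRetrievalQuery(**parsed) on the parsed JSON would raise a
# pydantic.ValidationError if any of the six fields were missing.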
def extract_json_with_regex(text):
    """Return the contents of the first fenced code block (optionally tagged
    ```json) in `text`, or None if no fenced block is found."""
    pattern = r"```(?:json)?\s*(.+?)\s*```"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[0]
    return None
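# Example: extract_json_with_regex('```json\n{"a": 1}\n```') returns '{"a": 1}'.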
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, type[GeneralRetrievalQuery]]:
    if prompt_name != "general":
        raise ValueError("Only the 'general' prompt is available in this version")
    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.

Important guidelines:

- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.

For each query, also provide a brief explanation of why this query would be effective in retrieving this document.

Format your response as a JSON object with the following structure:

{
  "broad_topical_query": "Your query here",
  "broad_topical_explanation": "Brief explanation",
  "specific_detail_query": "Your query here",
  "specific_detail_explanation": "Brief explanation",
  "visual_element_query": "Your query here",
  "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query.

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format."""
    return prompt, GeneralRetrievalQuery
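# The prompt asks for a JSON object, and the model typically wraps it in a
# ```json fence, which extract_json_with_regex() peels off before parsing.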
prompt, pydantic_model = get_retrieval_prompt("general")
# ----------------------- Input Preprocessing ----------------------- #
def _prep_data_for_input(image):
    """Build the chat-formatted, tokenized inputs for a single image + prompt."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
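# The processor returns a BatchFeature holding input_ids, attention_mask, and
# the vision tensors (pixel_values and image_grid_thw for Qwen2.5-VL).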
# ----------------------- Output Formatting ----------------------- #
def format_output(data: dict, output_format: str) -> str:
    """
    Convert the JSON data into the desired output format.
    output_format: "JSON", "Markdown", or "Table"
    """
    if output_format == "JSON":
        # Wrap in a code block so the Markdown component renders it verbatim.
        return f"```json\n{json.dumps(data, indent=2, ensure_ascii=False)}\n```"
    elif output_format == "Markdown":
        md_lines = []
        for key, value in data.items():
            md_lines.append(f"**{key.replace('_', ' ').title()}:** {value}")
        return "\n\n".join(md_lines)
    elif output_format == "Table":
        headers = ["Field", "Content"]
        separator = " | ".join(["---"] * len(headers))
        rows = [f"| {' | '.join(headers)} |", f"| {separator} |"]
        for key, value in data.items():
            # Escape pipes so a "|" inside a value cannot break the table row.
            cell = str(value).replace("|", "\\|")
            rows.append(f"| {key.replace('_', ' ').title()} | {cell} |")
        return "\n".join(rows)
    else:
        # Fall back to JSON for any unrecognized format.
        return f"```json\n{json.dumps(data, indent=2, ensure_ascii=False)}\n```"
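# Example: format_output({"broad_topical_query": "solar energy"}, "Markdown")
# produces "**Broad Topical Query:** solar energy".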
# ----------------------- Response Generation ----------------------- #
@spaces.GPU
def generate_response(image, output_format: str = "JSON"):
    inputs = _prep_data_for_input(image)
    inputs = inputs.to("cuda")
    # 200 tokens can truncate the six-field JSON mid-object; 512 leaves headroom.
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    try:
        # Prefer a fenced ```json block; fall back to parsing the raw text.
        json_str = extract_json_with_regex(output_text)
        if json_str:
            return format_output(json.loads(json_str), output_format)
        return format_output(json.loads(output_text), output_format)
    except Exception:
        gr.Warning("Failed to parse JSON from output")
        return output_text
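# Local smoke test (assumes a CUDA GPU; the image path is illustrative):
#   from PIL import Image
#   print(generate_response(Image.open("examples/sample_page.jpg"), "Markdown"))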
# ----------------------- Interface Title and Description ----------------------- #
title = "Elegant ColPali Query Generator using Qwen2.5-VL"
description = """**ColPali** is a multimodal approach optimized for document retrieval.
This interface uses the [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) model to generate relevant retrieval queries from a document image.

The generated queries include:

- **Broad Topical Query:** Covers the main subject of the document.
- **Specific Detail Query:** Focuses on a particular fact, figure, or point from the document.
- **Visual Element Query:** References a visual component (e.g., chart or graph) from the document.

Try the examples below, or upload your own document image to generate queries.
For more information, please see the associated blog post.
"""
examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]
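# These sample images are assumed to ship with the Space under examples/.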
# ----------------------- Custom CSS ----------------------- #
custom_css = """
body {
    background: #f7f9fb;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gradio-container {
    padding: 20px;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
footer {
    text-align: center;
    padding: 20px 0;
    font-size: 0.9em;
    color: #555;
}
"""
# ----------------------- Gradio Interface ----------------------- #
with gr.Blocks(css=custom_css, title=title) as demo:
    with gr.Column(variant="panel"):
        gr.Markdown(f"<header><h1>{title}</h1></header>")
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem("Query Generation"):
                gr.Markdown("### Generate Retrieval Queries from a Document Image")
                with gr.Row():
                    image_input = gr.Image(label="Upload Document Image", type="pil")
                with gr.Row():
                    output_format = gr.Radio(
                        choices=["JSON", "Markdown", "Table"],
                        value="JSON",
                        label="Output Format",
                        info="Select the desired output format.",
                    )
                generate_button = gr.Button("Generate Query")
                # Use gr.Markdown for the output so the Markdown and Table
                # formats render properly instead of showing raw text.
                output_text = gr.Markdown(label="Generated Query")
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Query Examples",
                        examples=examples,  # reuse the module-level list
                        inputs=image_input,
                    )
                generate_button.click(
                    fn=generate_response,
                    inputs=[image_input, output_format],
                    outputs=output_text,
                )
    gr.Markdown("<footer>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")

demo.launch()