Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,722 Bytes
3f53d8e 4e1ec1c 4eac50b 1796549 4e1ec1c 1796549 4eac50b 4e1ec1c 1796549 3f53d8e 4e1ec1c 4011b8e 4e1ec1c 1796549 4e1ec1c 1796549 4e1ec1c 1796549 4e1ec1c 1796549 4e1ec1c 3f53d8e 4e1ec1c 1796549 4e1ec1c 3f53d8e 4e1ec1c 7b6a165 4e1ec1c 1796549 e3dcfdd 6c3e99a e3dcfdd 6c3e99a e3dcfdd 10a9ffa e3dcfdd 10a9ffa 7b6a165 4eac50b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
import subprocess
import sys

# Install flash-attn at start-up, before the model further down is loaded with
# attn_implementation="flash_attention_2".
#
# The child environment is the current environment plus the skip-build flag.
# Passing only the flag as `env` would replace the whole environment, dropping
# PATH, HOME, proxy and pip configuration for the pip process.
# The command is an argument list run through the current interpreter, so no
# shell is involved and the package lands in the environment that runs this app.
# The exit status is deliberately not checked (best-effort, as before).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# Kept first, as in the original file.
# NOTE(review): the ZeroGPU `spaces` package is presumably meant to be imported
# before torch-dependent packages; confirm before reordering.
import spaces

import json
import os
from typing import Tuple, Type

import gradio as gr
import torch
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# Request the hf_transfer download backend for Hugging Face Hub downloads.
# NOTE(review): this is set after `transformers` has already been imported;
# huggingface_hub may read the flag at import time, so confirm it takes effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Checkpoint shared by the model weights and the matching processor.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Loaded once at import time and reused by every request.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Schema mirroring the JSON object that the "general" prompt asks the model to
# return: three retrieval queries, each paired with a short explanation.
# Documented with comments rather than a class docstring so the pydantic
# model's generated schema stays exactly as it was.
class GeneralRetrievalQuery(BaseModel):
    # Query covering the main subject of the document, and why it works.
    broad_topical_query: str
    broad_topical_explanation: str

    # Query targeting a particular fact, figure or point, and why it works.
    specific_detail_query: str
    specific_detail_explanation: str

    # Query referencing a chart, graph, image or other visual component,
    # and why it works.
    visual_element_query: str
    visual_element_explanation: str
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, Type[GeneralRetrievalQuery]]:
    """Return the prompt text and the response schema for a named prompt.

    Args:
        prompt_name: Name of the prompt to fetch. Only ``"general"`` is
            available.

    Returns:
        A ``(prompt, schema)`` pair: the instruction text to send to the model
        and the ``GeneralRetrievalQuery`` class itself (not an instance),
        whose fields match the JSON structure requested in the prompt.

    Raises:
        ValueError: If ``prompt_name`` is anything other than ``"general"``.
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")
    # The prompt body is kept flush left so that no indentation leaks into the
    # text that is sent to the model.
    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
Please generate 3 different types of retrieval queries:
1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present.
Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.
For each query, also provide a brief explanation of why this query would be effective in retrieving this document.
Format your response as a JSON object with the following structure:
{
"broad_topical_query": "Your query here",
"broad_topical_explanation": "Brief explanation",
"specific_detail_query": "Your query here",
"specific_detail_explanation": "Brief explanation",
"visual_element_query": "Your query here",
"visual_element_explanation": "Brief explanation"
}
If there are no relevant visual elements, replace the third query with another specific detail query.
Here is the document image to analyze:
<image>
Generate the queries based on this image and provide the response in the specified JSON format."""
    return prompt, GeneralRetrievalQuery
# Resolved once at import time through get_retrieval_prompt() instead of being
# hard-coded, so that more prompting options can be added later.
# `prompt` is read by _prep_data_for_input(); `pydantic_model` is not referenced
# anywhere else in the visible part of this file.
prompt, pydantic_model = get_retrieval_prompt("general")
def _prep_data_for_input(image):
    """Turn one document image into processor inputs for the model.

    A single-turn user conversation is built from the image and the
    module-level ``prompt``, rendered to text with the processor's chat
    template (generation prompt appended), and then passed to the processor
    together with the vision inputs extracted by ``process_vision_info``.

    Args:
        image: The document page image to place in the user message.

    Returns:
        The processor output for a batch of one, requested as PyTorch tensors
        (``return_tensors="pt"``) with padding enabled.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
    conversation = [user_turn]

    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)

    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    return model_inputs
def _parse_json_output(text):
    """Parse model output as JSON, tolerating text around the JSON object.

    The text is first parsed as-is. If that fails, the span from the first
    ``{`` to the last ``}`` is parsed instead, which recovers replies wrapped
    in a Markdown code fence or surrounded by extra prose.

    Raises:
        ValueError: If neither attempt yields valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    try:
        return json.loads(text)
    except ValueError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            # No candidate object to fall back to: surface the original error.
            raise
        return json.loads(text[start : end + 1])


@spaces.GPU
def generate_response(image):
    """Generate retrieval queries for a document image.

    Args:
        image: The document page image supplied by the Gradio interface.

    Returns:
        The JSON value parsed from the model's reply, or an empty dict when
        the reply cannot be parsed. In that case a Gradio warning is shown
        instead of raising.
    """
    inputs = _prep_data_for_input(image)
    inputs = inputs.to("cuda")
    # NOTE(review): the reply is capped at 200 new tokens; a longer reply would
    # be cut off before its closing brace and then fail to parse. Confirm the
    # cap is large enough for six fields.
    generated_ids = model.generate(**inputs, max_new_tokens=200)
    # Each generated sequence starts with its prompt tokens; slice them off so
    # that only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    try:
        return _parse_json_output(output_text[0])
    except (ValueError, RecursionError):
        # Only parse failures are treated as "no result"; other errors propagate.
        gr.Warning("Failed to parse JSON from output")
        return {}
# Static copy shown on the demo page.
title = "ColPali fine-tuning Query Generator"
description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
One way in which we might go about generating such a dataset is to use an VLM to generate synthetic queries for us.
This space uses the [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) to generate queries for a document, based on an input document image.
This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.
If you want to convert a PDF(s) to a dataset of page images you can try out the [ PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
"""

# One image in (handed to the callback as a PIL image), one JSON value out.
image_input = gr.Image(type="pil")
json_output = gr.Json()

demo = gr.Interface(
    fn=generate_response,
    inputs=image_input,
    outputs=json_output,
    title=title,
    description=description,
)
demo.launch()
|