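"""Gradio demo: Byaldi (ColPali) + Qwen2-VL OCR with keyword search.

Indexes an uploaded image with Byaldi, asks Qwen2-VL to read the text out of
it, and highlights user-supplied keywords in the extracted text.
"""
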
import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
import re

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Byaldi and Qwen2-VL models
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")  # Byaldi model
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to the selected device

# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
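
# Note (assumption about the runtime): on Hugging Face ZeroGPU Spaces the GPU
# is only attached while a function decorated with @spaces.GPU is executing;
# on a dedicated GPU (or CPU) instance the decorator is effectively a no-op.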

# Function for OCR and text extraction
@spaces.GPU(duration=120)  # Increased GPU duration to 120 seconds
def ocr_and_extract(image):
    try:
        # Save the uploaded image temporarily
        # (convert to RGB first: PIL cannot write RGBA images as JPEG)
        temp_image_path = "temp_image.jpg"
        image.convert("RGB").save(temp_image_path)

        # Index the image with Byaldi, forcing an overwrite of the existing index
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",  # Reuse the same index
            store_collection_with_index=False,
            overwrite=True,  # Overwrite the index for every new image
        )

        # Run a query against the indexed image (the result itself is not used
        # further in this demo; indexing the image is what matters here)
        results = rag_model.search("", k=1)

        # Prepare the input for Qwen2-VL: a single user turn containing the image
        image_data = Image.open(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                ],
            }
        ]

        # Render the chat template and collect the vision inputs
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        # Move the processor outputs to the selected device
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
        output_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Drop chat-template boilerplate such as "You are a helpful assistant"
        # (note this also drops any extracted line containing "user" or "system")
        filtered_output = [
            line
            for line in output_text[0].split("\n")
            if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])
        ]
        extracted_text = "\n".join(filtered_output).strip()

        # Clean up the temporary file
        os.remove(temp_image_path)

        return extracted_text

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

def search_keywords(extracted_text, keywords):
    if not extracted_text:
        return "No text extracted yet. Please upload an image."

    # Highlight matching keywords in the extracted text
    highlighted_text = extracted_text
    for keyword in keywords.split():
        highlighted_text = re.sub(
            f"({re.escape(keyword)})", r"<mark>\1</mark>", highlighted_text, flags=re.IGNORECASE
        )

    # Return the highlighted text as HTML
    return f"<div style='white-space: pre-wrap'>{highlighted_text}</div>"

# Gradio interface for image input and keyword search
with gr.Blocks() as iface:
    # Add a title at the top of the interface
    gr.HTML("<h1 style='text-align: center'>Byaldi + Qwen2VL</h1>")

    # Image upload and text extraction section
    with gr.Column():
        img_input = gr.Image(type="pil", label="Upload an Image")
        extracted_output = gr.Textbox(label="Extracted Text", interactive=False)  # Use Textbox to store text

        # Functionality to trigger the OCR and extraction
        img_button = gr.Button("Extract Text")
        img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)

    # Keyword search section
    with gr.Column():
        search_input = gr.Textbox(label="Enter keywords to search")
        search_output = gr.HTML(label="Search Results")

        # Functionality to search within the extracted text
        search_button = gr.Button("Search")
        search_button.click(fn=search_keywords, inputs=[extracted_output, search_input], outputs=search_output)

iface.launch()
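
# Optional: for long-running GPU calls on Spaces, enabling Gradio's request
# queue can help, e.g. iface.queue().launch()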