# Qwen2VL-OCR_CPU / app.py
import gradio as gr
from PIL import Image
import json
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import re
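
# Dependencies implied by the imports above: gradio, byaldi, transformers,
# qwen-vl-utils, torch and Pillow.
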
# Load models
def load_models():
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32,  # float32 for CPU
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor

RAG, model, processor = load_models()
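
# Note: the byaldi ColPali retriever (RAG) is loaded alongside the Qwen2-VL model
# but is not referenced by the extraction/search functions below.
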
# Global variable to store extracted text
extracted_text_global = ""
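
# The search function below reads this global, so "Extract Text" has to be run
# before "Search Keyword" can return matches.
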
# Function for OCR extraction
def extract_text(image):
    global extracted_text_global
    text_query = "Extract all the text in Sanskrit and English from the image."

    # Prepare message for Qwen model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

    # Process the image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to("cpu")  # Use CPU

    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    extracted_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    # Store extracted text in global variable
    extracted_text_global = extracted_text
    return extracted_text
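
# Example (sketch, kept as comments so it does not run at import time): calling
# extract_text directly with a local image, assuming a file named "sample.jpg"
# exists next to this script:
#   img = Image.open("sample.jpg")
#   print(extract_text(img))
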
# Function for keyword search within extracted text
def search_keyword(keyword):
    global extracted_text_global
    if not extracted_text_global:
        return "No extracted text available. Please extract text first.", "No matches found."

    keyword_lower = keyword.lower()
    sentences = extracted_text_global.split('. ')
    matched_sentences = []

    # Perform keyword search with highlighting
    for sentence in sentences:
        if keyword_lower in sentence.lower():
            highlighted_sentence = re.sub(
                f'({re.escape(keyword)})',
                r'<mark>\1</mark>',  # Highlight the matched keyword
                sentence,
                flags=re.IGNORECASE
            )
            matched_sentences.append(highlighted_sentence)

    search_results_str = "<br>".join(matched_sentences) if matched_sentences else "No matches found."
    return extracted_text_global, search_results_str
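
# Example (sketch) of the highlighting behaviour: if the extracted text were
# "Rama went to the forest. Sita followed.", then search_keyword("forest") would
# return the full text plus "Rama went to the <mark>forest</mark>" as the match.
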
# Gradio App wrappers for the two buttons
def app_extract(image):
    extracted_text = extract_text(image)
    return extracted_text

def app_search(keyword):
    extracted_text, search_results = search_keyword(keyword)
    return extracted_text, search_results
# Gradio UI with two buttons. gr.Interface cannot dispatch to a list of
# functions, so a Blocks layout is used to wire each button to its handler.
with gr.Blocks(title="OCR and Keyword Search in Images") as iface:
    gr.Markdown(
        "# OCR and Keyword Search in Images\n"
        "First, extract the text from an image, then search for a keyword in the extracted text."
    )

    image_input = gr.Image(type="pil", label="Upload an Image")
    keyword_input = gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")

    extract_button = gr.Button("Extract Text")
    search_button = gr.Button("Search Keyword")

    extracted_output = gr.Textbox(label="Extracted Text")
    search_output = gr.HTML(label="Search Results")

    # Link buttons to their respective functions
    extract_button.click(fn=app_extract, inputs=image_input, outputs=extracted_output)
    search_button.click(fn=app_search, inputs=keyword_input, outputs=[extracted_output, search_output])
# Launch Gradio App
iface.launch()
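
# Running `python app.py` serves the UI locally (Gradio defaults to port 7860);
# pass share=True to launch() if a temporary public link is needed.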