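"""Gradio demo: OCR with Qwen2-VL-2B-Instruct on CPU, plus keyword search with
HTML highlighting over the extracted text."""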
import re

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# Load the models once at startup.
def load_models():
    # ColPali retriever; loaded here but not used by the OCR flow below.
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    # float32, since this app targets CPU.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor

RAG, model, processor = load_models()
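# Note: on a GPU machine the usual change would be torch_dtype=torch.float16
# (or bfloat16) with model and inputs moved to "cuda"; everything here stays on CPU.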
# OCR: ask Qwen2-VL to transcribe all text in the image.
def extract_text_from_image(image):
    text_query = "Extract all the text in Sanskrit and English from the image."

    # Build a chat-style message pairing the image with the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

    # Render the chat template and preprocess the visual inputs.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # keep everything on CPU

    # Generate, then strip the prompt tokens so only the new text is decoded.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    extracted_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return extracted_text
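# Manual check (hypothetical local file; the Gradio UI below is the real entry point):
#   print(extract_text_from_image(Image.open("sample.png")))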
# Keyword search: return sentences containing the keyword, with matches highlighted.
def search_keyword_in_text(extracted_text, keyword):
    keyword_lower = keyword.lower()
    # Rough sentence split on ". "; good enough for display purposes.
    sentences = extracted_text.split('. ')
    matched_sentences = []
    for sentence in sentences:
        if keyword_lower in sentence.lower():
            # Wrap every case-insensitive occurrence of the keyword in <mark> tags.
            highlighted_sentence = re.sub(
                f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE
            )
            matched_sentences.append(highlighted_sentence)
    return matched_sentences if matched_sentences else ["No matches found."]
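# Example:
#   search_keyword_in_text("Hello world. Namaste.", "world")
#   -> ['Hello <mark>world</mark>']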
# Gradio callbacks
def app_extract_text(image):
    return extract_text_from_image(image)

def app_search_keyword(extracted_text, keyword):
    search_results = search_keyword_in_text(extracted_text, keyword)
    return "<br>".join(search_results)
title_html = """
<h1><span class="gradient-text" id="text">IIT Roorkee (GOT ASSIGNMENT)</span></h1>
"""
# Gradio Interface
with gr.Blocks() as iface:
    gr.HTML(title_html)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            extract_button = gr.Button("Extract Text")
            extracted_text_output = gr.Textbox(label="Extracted Text")
            extract_button.click(app_extract_text, inputs=image_input, outputs=extracted_text_output)
        with gr.Column():
            keyword_input = gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword")
            search_button = gr.Button("Search Keyword")
            search_results_output = gr.HTML(label="Search Results")
            search_button.click(app_search_keyword, inputs=[extracted_text_output, keyword_input], outputs=search_results_output)

# Launch Gradio App
iface.launch()
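# On Hugging Face Spaces this file is executed directly, so a module-level
# launch() is sufficient; a __main__ guard would also work when run locally.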