# OCR_1 / app.py
# Author: Adi-yogi
# Commit: dec8090 (verified) - "prompt updation"
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration,AutoProcessor, AutoModelForCausalLM
from qwen_vl_utils import process_vision_info
import re
# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
# Pixel budget handed to the processor for resizing input images, expressed as
# a multiple of 28*28 (i.e. between 256 and 1080 units of 28x28 pixels).
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1080 * 28 * 28

# (processor, model) pairs keyed by device string. Loading the checkpoint is by
# far the most expensive step, so it is done once per device and then reused
# across requests instead of on every call.
_MODEL_CACHE = {}


def _get_model_and_processor(device):
    """Return the cached ``(processor, model)`` pair for *device*, loading it on first use."""
    if device not in _MODEL_CACHE:
        processor = AutoProcessor.from_pretrained(
            MODEL_ID, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
        )
        model = (
            Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_ID, torch_dtype=torch.bfloat16
            )
            .to(device)
            .eval()
        )
        _MODEL_CACHE[device] = (processor, model)
    return _MODEL_CACHE[device]


def _extract_text(image, device):
    """Run the OCR prompt on *image* and return the model's decoded answer as a string."""
    processor, model = _get_model_and_processor(device)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Extract just the text from the image and nothing else"},
            ],
        }
    ]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Only images are sent, so the video part of the vision info is unused.
    image_inputs, _ = process_vision_info(messages)
    model_inputs = processor(
        text=[prompt],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=600)
    # The generated sequences start with the prompt tokens; drop them so only
    # the newly generated answer is decoded.
    trimmed_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # One prompt in, one decoded sequence out.
    return decoded[0]


def _describe_matches(text, pattern):
    """Return a newline-separated report of every match of regex *pattern* in *text*.

    Positions are character offsets into *text* itself. An empty pattern or an
    invalid regex yields an explanatory message instead of raising.
    """
    if not pattern:
        # An empty regex matches at every position, which is never useful here.
        return "No pattern provided."
    try:
        compiled = re.compile(pattern)
    except re.error as err:
        return f"Invalid regex pattern '{pattern}': {err}"

    report = [
        f"Found '{match.group()}' at position {match.start()} to {match.end()}"
        for match in compiled.finditer(text)
    ]
    if not report:
        return f"'{pattern}' not found in the text."
    return "\n".join(report)


def model_inference(image, pattern):
    """Extract the text from *image* with Qwen2-VL and search it for *pattern*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        pattern: Regular expression to search for in the extracted text.

    Returns:
        ``(extracted_text, match_report)`` - two strings, one for each Gradio
        output textbox.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    extracted_text = _extract_text(image, device)
    # Search the plain string (not str() of a list) so that reported positions
    # and newline handling refer to the text the user actually sees.
    return extracted_text, _describe_matches(extracted_text, pattern)
# Gradio interface: an image and a regex pattern go in; the extracted text and
# a match report come out. Component order matches model_inference's
# parameters (image, pattern) and its return tuple (text, report).
inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(lines=1, placeholder="Enter key to search", label="Pattern"),
]
outputs = [
    gr.Textbox(label="Extracted Text"),  # model's extracted text
    gr.Textbox(label="Matches"),  # regex match report
]

demo = gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=outputs,
    title="Image and Text Inference",
    description=(
        "Upload an image to extract its text, and provide a regex pattern "
        "to search in the extracted text."
    ),
)
# NOTE(review): share=True asks Gradio for a public share link when run
# locally - confirm this exposure is intended.
demo.launch(share=True)