test-two / app.py
gauri-sharan's picture
Create app.py
2f3144c verified
raw
history blame
3 kB
import os
import tempfile
import traceback

import gradio as gr
import spaces  # Ensure import for GPU management
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# Checkpoint shared by the Qwen2-VL model and its processor.
_QWEN_CHECKPOINT = "Qwen/Qwen2-VL-2B-Instruct"

# Byaldi retrieval model, loaded once at import time (no .cuda() call here).
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Qwen2-VL generation model, requested in bfloat16; it is not moved to a GPU
# at load time.
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    _QWEN_CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Processor used to build the chat prompt and model inputs for Qwen2-VL.
processor = AutoProcessor.from_pretrained(
    _QWEN_CHECKPOINT,
    trust_remote_code=True,
)
@spaces.GPU  # Decorate the function for GPU management
def ocr_and_extract(image, text_query):
    """Answer ``text_query`` about ``image`` using Byaldi and Qwen2-VL.

    The image is written to a private temporary directory, indexed and
    searched with the Byaldi model, and then passed together with the query
    to Qwen2-VL for generation (at most 50 new tokens).

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when nothing was uploaded.
        text_query: Prompt text supplied by the Gradio ``Textbox`` input.

    Returns:
        The text generated by Qwen2-VL with the echoed prompt removed, or a
        string starting with ``"Error: "`` if no image was given or any step
        raised an exception.
    """
    if image is None:
        return "Error: No image provided."

    try:
        # JPEG cannot store alpha/palette data, so normalise the mode before
        # saving; otherwise PNG uploads with transparency fail in save().
        rgb_image = image.convert("RGB")

        # A per-call temporary directory avoids clashes between requests and
        # is removed automatically, even when indexing or searching raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
            rgb_image.save(temp_image_path)

            # Index the image with Byaldi
            rag_model.index(
                input_path=temp_image_path,
                index_name="image_index",
                store_collection_with_index=False,
                overwrite=True,
            )

            # Perform the search query on the indexed image.
            # NOTE: the search result is not consumed by the generation step
            # below; the call is kept so the retrieval stage still runs.
            rag_model.search(text_query, k=1)

        # Prepare the input for Qwen2-VL. The in-memory image is used
        # directly, so no file handle is left open on the temporary file.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": rgb_image},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Process the message and prepare for Qwen2-VL
        text_input = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _ = process_vision_info(messages)
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the Qwen2-VL model and inputs to GPU
        qwen_model.to("cuda")
        inputs = {name: tensor.to("cuda") for name, tensor in inputs.items()}

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)

        # generate() returns prompt + completion; drop the prompt tokens so
        # the chat template and the user's query are not echoed back.
        prompt_length = inputs["input_ids"].shape[1]
        new_token_ids = generated_ids[:, prompt_length:]

        output_text = processor.batch_decode(
            new_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    except Exception as e:
        # Top-level UI boundary: log the traceback and show the message to
        # the user instead of crashing the Gradio handler.
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"
# Input widgets: the uploaded picture (delivered as a PIL image) and the
# free-text query box.
image_input = gr.Image(type="pil")
query_input = gr.Textbox(label="Enter your query (optional)")

# Gradio interface wiring the two inputs to ocr_and_extract, with plain-text
# output.
iface = gr.Interface(
    fn=ocr_and_extract,
    inputs=[image_input, query_input],
    outputs="text",
    title="Image OCR with Byaldi + Qwen2-VL",
    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR.",
)

# Start the Gradio app as soon as the module is executed.
iface.launch()